1 | #pragma once |
2 | |
3 | #include <ATen/core/ivalue.h> |
4 | #include <ATen/core/operator_name.h> |
5 | #include <c10/macros/Export.h> |
6 | #include <c10/util/Optional.h> |
7 | #include <c10/util/SmallVector.h> |
8 | #include <c10/util/variant.h> |
9 | |
10 | #include <array> |
11 | #include <atomic> |
12 | #include <functional> |
13 | #include <memory> |
14 | |
15 | namespace c10 { |
16 | class TORCH_API OperatorHandle; |
17 | } |
18 | |
19 | namespace at { |
20 | |
// Kind of record function scope.
// Enumerators are numbered contiguously from 0 so that NUM_SCOPES can be used
// as the size of per-scope arrays.
enum class C10_API_ENUM RecordScope : uint8_t {
  // c10/ATen ops, autograd nodes
  FUNCTION = 0,
  // Functions/nodes called from the autograd
  BACKWARD_FUNCTION,
  // TorchScript functions, methods
  TORCHSCRIPT_FUNCTION,
  // Kernel Function dtype Tag
  KERNEL_FUNCTION_DTYPE,
  // Torchbind custom class
  CUSTOM_CLASS,
  // Generic Build Feature
  BUILD_FEATURE,
  // Lite interpreter scope events (see
  // RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS)
  LITE_INTERPRETER,
  // User defined scope (e.g. with record_function())
  USER_SCOPE,
  // Scopes for static runtime, a specialized TorchScript interpreter
  STATIC_RUNTIME_OP,
  STATIC_RUNTIME_MODEL,
  NUM_SCOPES, // must be the last in the list
};
44 | |
45 | } // namespace at |
46 | |
47 | namespace std { |
48 | template <> |
49 | struct hash<at::RecordScope> { |
50 | size_t operator()(const at::RecordScope& sc) const { |
51 | return static_cast<std::size_t>(sc); |
52 | } |
53 | }; |
54 | } // namespace std |
55 | |
56 | namespace at { |
57 | |
// A string handle that exposes its contents as a C string through str().
//
// It can be built in three ways:
//  - default: holds no string, str() returns nullptr;
//  - from a `const char*`: the pointer is stored as is and is NOT owned, so
//    the pointed-to characters must outlive every copy of the view;
//  - from a `std::string`: the string is moved into shared storage that is
//    kept alive by all copies of the view.
struct TORCH_API StringView {
  StringView() : StringView(nullptr) {}
  explicit StringView(const char* str_ptr)
      : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {}
  explicit StringView(std::string str)
      : owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
        str_ptr_(owned_str_ptr_->c_str()) {}

  // Returns the underlying C string; nullptr for a default-constructed view
  // (or one constructed from a null pointer).
  const char* str() const {
    return str_ptr_;
  }

  // Streams the characters of `dt`. A view that holds no string writes
  // nothing: inserting a null `const char*` into a stream is not allowed.
  friend std::ostream& operator<<(std::ostream& os, const StringView& dt) {
    if (dt.str() != nullptr) {
      os << dt.str();
    }
    return os;
  }

  // Compares by content. Two views that hold no string are equal to each
  // other and different from any view that holds one (even an empty string);
  // strcmp is only reached when both pointers are non-null.
  friend bool operator==(const StringView& lhs, const StringView& rhs) {
    const char* const left = lhs.str();
    const char* const right = rhs.str();
    if (left == right) {
      return true;
    }
    if (left == nullptr || right == nullptr) {
      return false;
    }
    return strcmp(left, right) == 0;
  }

  friend bool operator!=(const StringView& lhs, const StringView& rhs) {
    return !(lhs == rhs);
  }

 private:
  // Non-null only when the view was constructed from a std::string.
  std::shared_ptr<std::string> owned_str_ptr_;
  // Points into *owned_str_ptr_ when owned, otherwise at caller's storage.
  const char* str_ptr_;
};
87 | |
// Soft limit on the number of callbacks to use; it is the size parameter of
// the c10::SmallVector aliases declared in this header (CallbackHandles,
// ObserverContextList, StepCallbacks::StartEndPairs).
constexpr std::size_t kSoftLimitCallbacks = 4;
90 | |
// Polymorphic base for the per-observer context objects that can be attached
// to a RecordFunction. It carries no data of its own; observers derive from it
// to hold whatever they need. The constructor is protected so that only
// derived types can be instantiated, and the destructor is virtual so that a
// context can be destroyed through a pointer to this base.
struct ObserverContext {
  virtual ~ObserverContext() = default;

 protected:
  ObserverContext() = default;
};
99 | |
100 | typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles; |
101 | typedef c10::SmallVector<std::unique_ptr<ObserverContext>, kSoftLimitCallbacks> |
102 | ObserverContextList; |
103 | typedef uint64_t RecordFunctionHandle; |
104 | struct RecordFunction; |
105 | |
106 | // |
107 | // PyTorch callbacks/observers API: |
108 | // |
109 | |
110 | /** |
111 | * RecordFunctionCallback represents a pair of callbacks to be used with |
112 | * RecordFunction, members: |
113 | * start, end - the callbacks to run when entering and exiting the scope; |
114 | * optionally, the start callback may return an ObserverContext which will |
115 | * be passed to the end callback, use appropriate constructor accordingly. |
116 | * needs_inputs - whether the callbacks need the inputs passed from the |
117 | * observed function/range; NOTE: passing the inputs incurs an additional |
118 | * overhead; sampling_probability - if not 1.0, then the callback is |
119 | * probabilistically sampled to run; NOTE: start and end callbacks always run as |
120 | * a pair and are sampled together; scopes - types of scopes to execute the |
121 | * callbacks on (see RecordScope); passing empty set means the callbacks will be |
122 | * executed for all possible scope types should_run - optional function that |
123 | * returns whether this callback should run; overwrites the effect of setting |
124 | * sampling_probability |
125 | */ |
126 | class TORCH_API RecordFunctionCallback { |
127 | public: |
128 | using StartCallback = |
129 | std::unique_ptr<ObserverContext> (*)(const RecordFunction&); |
130 | using EndCallback = void (*)(const RecordFunction&, ObserverContext*); |
131 | |
132 | // This interface supports observers that require passing an ObserverContext |
133 | // between start and end callbacks. |
134 | explicit RecordFunctionCallback( |
135 | StartCallback start, |
136 | EndCallback end = nullptr) |
137 | : start_(start), end_(end) { |
138 | scopes_.fill(true); |
139 | } |
140 | |
141 | RecordFunctionCallback& needsInputs(bool needs_inputs) { |
142 | needs_inputs_ = needs_inputs; |
143 | return *this; |
144 | } |
145 | |
146 | RecordFunctionCallback& needsOutputs(bool needs_outputs) { |
147 | needs_outputs_ = needs_outputs; |
148 | return *this; |
149 | } |
150 | |
151 | RecordFunctionCallback& needsIds(bool needs_ids) { |
152 | needs_ids_ = needs_ids; |
153 | return *this; |
154 | } |
155 | |
156 | RecordFunctionCallback& samplingProb(double sampling_prob) { |
157 | TORCH_CHECK( |
158 | sampling_prob >= 0.0 && sampling_prob <= 1.0, |
159 | "Invalid sampling probability" ); |
160 | sampling_prob_ = sampling_prob; |
161 | return *this; |
162 | } |
163 | |
164 | RecordFunctionCallback& scopes( |
165 | const std::unordered_set<RecordScope, std::hash<RecordScope>>& scopes) { |
166 | if (!scopes.empty()) { |
167 | scopes_.fill(false); |
168 | for (auto sc : scopes) { |
169 | scopes_[static_cast<size_t>(sc)] = true; |
170 | } |
171 | } else { |
172 | scopes_.fill(true); |
173 | } |
174 | return *this; |
175 | } |
176 | |
177 | bool needsInputs() const { |
178 | return needs_inputs_; |
179 | } |
180 | |
181 | bool needsOutputs() const { |
182 | return needs_outputs_; |
183 | } |
184 | |
185 | bool needsIds() const { |
186 | return needs_ids_; |
187 | } |
188 | |
189 | double samplingProb() const { |
190 | return sampling_prob_; |
191 | } |
192 | |
193 | bool checkScope(RecordScope sc) const { |
194 | return scopes_[(size_t)sc]; |
195 | } |
196 | |
197 | StartCallback start() const { |
198 | return start_; |
199 | } |
200 | |
201 | EndCallback end() const { |
202 | return end_; |
203 | } |
204 | |
205 | private: |
206 | StartCallback start_; |
207 | EndCallback end_; |
208 | double sampling_prob_ = 1.0; |
209 | std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {}; |
210 | bool needs_inputs_ = false; |
211 | bool needs_outputs_ = false; |
212 | bool needs_ids_ = false; |
213 | }; |
214 | |
215 | // Notes: |
216 | // - two types of callbacks are provided: thread local and global |
217 | // - thread local callbacks are added/removed only for the given thread |
218 | // and are stored locally for each thread and separately from the list |
219 | // of the global callbacks |
220 | // - global callbacks are stored in a single per process list and are |
221 | // invoked by every RecordFunction, in addition to the thread local |
222 | // callbacks specific to the given thread |
223 | // - we allow the added callbacks to be sampled, by specifying a sampling |
224 | // probability for each callback pair, if the start callback is |
225 | // not picked to run, the corresponding end callback won't be called |
226 | // - a typical use case for the global callbacks is passive monitoring |
227 | // in the background (e.g. fleet-wide monitoring), without focusing on |
228 | // the specific piece of code |
229 | // - in contrast, thread local callbacks are enabled locally, on demand, |
230 | // for the specific piece of code (range) and are not sampled |
231 | // - a typical use case for thread local callbacks is profiler and code |
232 | // execution tracer |
233 | // - note, thread local callbacks are automatically propagated with |
234 | // ThreadLocalState across JIT continuations and async tasks (at::launch) |
235 | |
// Identifier of a registered callback, as returned by addThreadLocalCallback
// and addGlobalCallback.
using CallbackHandle = uint64_t;

// Handle value that does not refer to any callback.
constexpr CallbackHandle INVALID_CALLBACK_HANDLE = 0;
239 | |
240 | // It is unnecessary to use atomic operations for enabling |
241 | // thread-local function callbacks. Moreover, it prevents saving to |
242 | // ThreadLocalState because std::atomic is non-copyable. |
243 | struct RecordFunctionCallbacksEntry { |
244 | RecordFunctionCallbacksEntry(RecordFunctionCallback&& cb, CallbackHandle h) |
245 | : callback_(cb), handle_(h) {} |
246 | |
247 | RecordFunctionCallback callback_; |
248 | bool enabled_{true}; |
249 | CallbackHandle handle_; |
250 | }; |
251 | |
// Holds pairs (callbacks, unique_id); each entry additionally carries an
// enabled flag (see RecordFunctionCallbacksEntry).
using RecordFunctionCallbacks = std::vector<RecordFunctionCallbacksEntry>;
254 | |
255 | // Generated by the callback managers to determine which functions to run. |
256 | struct StepCallbacks { |
257 | StepCallbacks() = default; |
258 | StepCallbacks(uint64_t thread_id, RecordScope scope) |
259 | : thread_id_{thread_id}, scope_{scope} {} |
260 | |
261 | bool empty() const { |
262 | return callbacks_.empty(); |
263 | } |
264 | |
265 | struct StartEndPair { |
266 | RecordFunctionCallback::StartCallback start_; |
267 | RecordFunctionCallback::EndCallback end_; |
268 | }; |
269 | |
270 | using StartEndPairs = c10::SmallVector<StartEndPair, kSoftLimitCallbacks>; |
271 | |
272 | StartEndPairs callbacks_; |
273 | uint64_t thread_id_{0}; |
274 | RecordScope scope_{RecordScope::FUNCTION}; |
275 | bool needs_inputs_{false}; |
276 | bool needs_outputs_{false}; |
277 | bool needs_ids_{false}; |
278 | }; |
279 | |
280 | struct TORCH_API RecordFunction { |
281 | // Default constructor is used with before function called afterwards: |
282 | // scope - record scope that this function tracks |
283 | // pre_sampled - whether this RecordFunction was already pre-sampled with |
284 | // kLowProb probability |
285 | explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION); |
286 | explicit RecordFunction(StepCallbacks&& step_callbacks); |
287 | |
288 | template <typename F> |
289 | void before( |
290 | F fn, |
291 | c10::ArrayRef<const c10::IValue> args, |
292 | int64_t current_sequence_nr = -1) { |
293 | if (!isActive()) { |
294 | return; |
295 | } |
296 | inputs_ = args; |
297 | before(fn, current_sequence_nr); |
298 | } |
299 | |
300 | template <typename F> |
301 | void before( |
302 | F fn, |
303 | const std::vector<IValue>* args, |
304 | int64_t current_sequence_nr = -1) { |
305 | before( |
306 | std::move(fn), |
307 | c10::ArrayRef<const c10::IValue>(args->data(), args->size()), |
308 | current_sequence_nr); |
309 | } |
310 | |
311 | // Destructor calls end callbacks |
312 | virtual ~RecordFunction(); |
313 | |
314 | RecordFunction(const RecordFunction&) = delete; |
315 | RecordFunction& operator=(const RecordFunction&) = delete; |
316 | |
317 | const char* name() const; |
318 | |
319 | int64_t seqNr() const { |
320 | return sequence_nr_; |
321 | } |
322 | |
323 | c10::ArrayRef<const IValue> inputs() const { |
324 | #ifndef NDEBUG |
325 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
326 | inputs_valid_, "Called inputs() outside RecordFunction start callback" ); |
327 | #endif |
328 | return inputs_; |
329 | } |
330 | |
331 | const std::vector<c10::IValue>& outputs() const { |
332 | return outputs_; |
333 | } |
334 | |
335 | void setOutputs(std::vector<c10::IValue>&& outputs) { |
336 | outputs_ = std::move(outputs); |
337 | } |
338 | |
339 | void setOutputs(c10::ArrayRef<c10::IValue> outputs) { |
340 | outputs_ = outputs.vec(); |
341 | } |
342 | |
343 | size_t num_inputs() const; |
344 | size_t num_outputs() const; |
345 | |
346 | // Retrieves the thread_id that this RecordFunction ran start callbacks with. |
347 | // Useful for writing thread safe end callbacks that may be potentially |
348 | // executed in a different thread (async ops) |
349 | uint64_t threadId() const { |
350 | return step_callbacks_.thread_id_; |
351 | } |
352 | |
353 | // For backward functions - thread id of the corresponding forward function, |
354 | // or zero otherwise; |
355 | // used alongside with sequence number to correlate backward functions with |
356 | // the forward ones |
357 | uint64_t forwardThreadId() const { |
358 | return fwd_thread_id_; |
359 | } |
360 | |
361 | void setForwardThreadId(uint64_t thread_id) { |
362 | fwd_thread_id_ = thread_id; |
363 | } |
364 | |
365 | RecordScope scope() const { |
366 | return step_callbacks_.scope_; |
367 | } |
368 | |
369 | // Returns logical thread_id for the current thread |
370 | static uint64_t currentThreadId(); |
371 | |
372 | // Internal functions, do not use directly; |
373 | // used in python's context manager |
374 | |
375 | // before functions initialize RecordFunction members and call |
376 | // start callbacks |
377 | using schema_ref_t = std::reference_wrapper<const c10::FunctionSchema>; |
378 | void before(const char* name, int64_t sequence_nr = -1); |
379 | void before(std::string name, int64_t sequence_nr = -1); |
380 | void before(schema_ref_t schema, int64_t sequence_nr = -1); |
381 | |
382 | // Sets node ID for distributed profiling |
383 | static void setDefaultNodeId(int64_t defaultNodeId); |
384 | // Gets node ID for distributed profiling |
385 | static int64_t getDefaultNodeId(); |
386 | |
387 | // Calls end callbacks. After end(), accessors will no longer provide useful |
388 | // results. |
389 | void end(); |
390 | |
391 | // Internal-only, used only force async event for distributed events |
392 | // profiling. |
393 | void _setAsync(); |
394 | |
395 | // Returns whether this RecordFunction corresponds to an async event orn ot. |
396 | bool isAsync() const; |
397 | |
398 | // Internal-only, used to denote out variant used for Static Runtime execution |
399 | void _setStaticRuntimeOutVariant(); |
400 | bool isStaticRuntimeOutVariant() const; |
401 | |
402 | RecordFunctionHandle handle() const { |
403 | return handle_; |
404 | } |
405 | |
406 | c10::optional<OperatorName> operator_name() const; |
407 | |
408 | // This method returns a copy of the FunctionSchema and can be expensive. |
409 | c10::optional<FunctionSchema> operator_schema() const; |
410 | |
411 | void setHandle(RecordFunctionHandle handle) { |
412 | handle_ = handle; |
413 | } |
414 | |
415 | // Whether this RecordFunction runs any callbacks. |
416 | bool isActive() const { |
417 | return !step_callbacks_.empty(); |
418 | } |
419 | |
420 | bool needsInputs() const { |
421 | return step_callbacks_.needs_inputs_; |
422 | } |
423 | |
424 | bool needsOutputs() const { |
425 | return step_callbacks_.needs_outputs_; |
426 | } |
427 | |
428 | int64_t debugHandle() const { |
429 | return debug_handle_; |
430 | } |
431 | |
432 | void setDebugHandle(int64_t debug_handle) { |
433 | debug_handle_ = debug_handle; |
434 | } |
435 | |
436 | void invalidateInputs() { |
437 | #ifndef NDEBUG |
438 | inputs_valid_ = false; |
439 | #endif |
440 | } |
441 | |
442 | private: |
443 | void runStartCallbacks(); |
444 | |
445 | StepCallbacks step_callbacks_; |
446 | |
447 | // In cases when RecordFunction might be active but we chose not to |
448 | // use the observers (e.g. operator is not observed), this boolean |
449 | // flag is used to check whether the start callbacks were called |
450 | bool called_start_callbacks_ = false; |
451 | |
452 | #ifndef NDEBUG |
453 | bool inputs_valid_ = false; |
454 | #endif |
455 | |
456 | // Stores various ObserverContext objects with event metadata for callbacks. |
457 | ObserverContextList ctx_; |
458 | |
459 | c10::variant<std::string, schema_ref_t> fn_; |
460 | |
461 | int64_t sequence_nr_ = -1; |
462 | c10::ArrayRef<const IValue> inputs_; |
463 | std::vector<c10::IValue> outputs_; |
464 | |
465 | // For backward functions - thread id of the the forward function |
466 | uint64_t fwd_thread_id_ = 0; |
467 | |
468 | // Unique id for this RecordFunction, used in callbacks to track start |
469 | // and end of ranges |
470 | RecordFunctionHandle handle_{0}; |
471 | |
472 | // Whether this record_function corresponds to an async event or not. Async |
473 | // events can complete in different threads or follow a future-like pattern |
474 | // of use. |
475 | bool is_async_{false}; |
476 | |
477 | // Debug handles are used for lazy annotation of module hierarchy |
478 | // and callstack. |
479 | // This is specifically is useful for mobile runtime, where generated |
480 | // debug handles can be lazily symbolicated using debug information |
481 | int64_t debug_handle_{-1}; |
482 | |
483 | // Whether this RecordFunction is used for an out variant run with |
484 | // Static Runtime |
485 | bool is_static_runtime_out_variant_{false}; |
486 | }; |
487 | |
// Returns the StepCallbacks to use for a RecordFunction of the given scope.
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);

// Variant of getStepCallbacks returning an optional.
// NOTE(review): presumably the result is c10::nullopt when no callback would
// run for `scope` - confirm against the implementation.
TORCH_API c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(
    RecordScope scope);
492 | |
493 | namespace detail { |
494 | template <typename Inputs, typename F, typename... Args> |
495 | void record_function_with_scope( |
496 | RecordFunction& guard, |
497 | F fn, |
498 | const Inputs& inputs, |
499 | Args&&... args) { |
500 | if (guard.needsInputs()) { |
501 | guard.before( |
502 | fn, |
503 | c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()), |
504 | std::forward<Args>(args)...); |
505 | } else { |
506 | guard.before(fn, std::forward<Args>(args)...); |
507 | } |
508 | } |
509 | |
510 | template <typename Inputs, typename F, typename... Args> |
511 | void record_function_with_scope_and_debug_handle( |
512 | RecordFunction& guard, |
513 | F fn, |
514 | int64_t debug_handle, |
515 | const Inputs& inputs, |
516 | Args&&... args) { |
517 | guard.setDebugHandle(debug_handle); |
518 | if (guard.needsInputs()) { |
519 | guard.before( |
520 | fn, |
521 | c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()), |
522 | std::forward<Args>(args)...); |
523 | } else { |
524 | guard.before(fn, std::forward<Args>(args)...); |
525 | } |
526 | } |
527 | |
528 | template <typename F, typename... Args> |
529 | void record_function_with_scope( |
530 | RecordFunction& guard, |
531 | F fn, |
532 | c10::ArrayRef<const c10::IValue> inputs, |
533 | Args&&... args) { |
534 | return record_function_with_scope< |
535 | c10::ArrayRef<const c10::IValue>, |
536 | F, |
537 | Args...>(guard, std::move(fn), inputs, std::forward<Args>(args)...); |
538 | } |
539 | |
540 | template <typename F, typename... Args> |
541 | void record_function_with_scope_and_debug_handle( |
542 | RecordFunction& guard, |
543 | F fn, |
544 | int64_t debug_handle, |
545 | c10::ArrayRef<const c10::IValue> inputs, |
546 | Args&&... args) { |
547 | return record_function_with_scope_and_debug_handle< |
548 | c10::ArrayRef<const c10::IValue>, |
549 | F, |
550 | Args...>( |
551 | guard, std::move(fn), debug_handle, inputs, std::forward<Args>(args)...); |
552 | } |
553 | |
554 | } // namespace detail |
555 | |
// Declares a local at::RecordFunction named `guard` for `scope` and, when it
// is active, starts it for `fn` with `inputs`.
// optional argument - function's seq_no
// NOTE: expands to a declaration followed by an `if` statement, so it can be
// used at most once per C++ scope and not as the body of an unbraced
// if/else/loop.
#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
  at::RecordFunction guard(scope);                         \
  if (guard.isActive()) {                                  \
    ::at::detail::record_function_with_scope(              \
        guard, fn, inputs, ##__VA_ARGS__);                 \
  }
563 | |
// Like RECORD_FUNCTION_WITH_SCOPE, but additionally stores `outputs` on the
// guard when its callbacks need outputs. `inputs` is passed to
// RecordFunction::before() only when the callbacks need inputs.
// Declares a local named `guard`; optional trailing argument - function's
// seq_no.
#define RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \
    scope, fn, inputs, outputs, ...)               \
  at::RecordFunction guard(scope);                 \
  if (guard.isActive()) {                          \
    if (guard.needsInputs()) {                     \
      guard.before(fn, inputs, ##__VA_ARGS__);     \
    } else {                                       \
      guard.before(fn, ##__VA_ARGS__);             \
    }                                              \
    if (guard.needsOutputs()) {                    \
      guard.setOutputs(outputs);                   \
    }                                              \
  }
577 | |
// Records `fn` with `inputs` under RecordScope::FUNCTION;
// optional argument - function's seq_no
#define RECORD_FUNCTION(fn, inputs, ...) \
  RECORD_FUNCTION_WITH_SCOPE(            \
      at::RecordScope::FUNCTION, fn, inputs, ##__VA_ARGS__)

// Records `mn` with `inputs` under RecordScope::TORCHSCRIPT_FUNCTION.
#define RECORD_TORCHSCRIPT_FUNCTION(mn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs)

// Records `fn` with `inputs` and `outputs` under RecordScope::FUNCTION;
// optional argument - function's seq_no
#define RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(fn, inputs, outputs, ...) \
  RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS(                          \
      at::RecordScope::FUNCTION, fn, inputs, outputs, ##__VA_ARGS__)
588 | |
// Custom user scopes in C++; similar to Python's 'with record_function("..."):'
// Records `fn` under RecordScope::USER_SCOPE with an empty input list.
#define RECORD_USER_SCOPE(fn) \
  RECORD_FUNCTION_WITH_SCOPE( \
      at::RecordScope::USER_SCOPE, fn, c10::ArrayRef<const c10::IValue>{})

// RECORD_USER_SCOPE with inputs
#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::USER_SCOPE, fn, inputs)
597 | |
// Helper macro to pass in debug handle that is used to
// post process events.
// Declares a local at::RecordFunction named `guard` for `scope` and, when it
// is active, sets `debug_handle` on it and starts it for `fn` with `inputs`.
#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(             \
    scope, fn, debug_handle, inputs, ...)                      \
  at::RecordFunction guard(scope);                             \
  if (guard.isActive()) {                                      \
    ::at::detail::record_function_with_scope_and_debug_handle( \
        guard, fn, debug_handle, inputs, ##__VA_ARGS__);       \
  }

// Helper macros to record LITE INTERPRETER scope events with debug handles
#define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \
    fn, debug_handle, inputs)                           \
  RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(            \
      at::RecordScope::LITE_INTERPRETER, fn, debug_handle, inputs)
613 | |
// Bookend to the RECORD_FUNCTION macros. Use this after the kernel
// launch to let the profiler bind the outputs to the op that produced
// them. Note that guard is declared by RECORD_FUNCTION so this macro
// needs to be called from the same scope as RECORD_FUNCTION.
//
// `outputs` is parenthesized in the expansion so that any expression (not
// just a plain identifier) can be passed. It is evaluated twice when outputs
// are needed and not at all otherwise, so it must not have side effects.
#define RECORD_OUTPUTS(outputs)                                        \
  if (guard.needsOutputs()) {                                          \
    guard.setOutputs(                                                  \
        std::vector<c10::IValue>((outputs).begin(), (outputs).end())); \
  }
623 | |
/**
 * addThreadLocalCallback adds a thread local callback to run with
 * RecordFunction, returns handle to use with removeCallback
 */
TORCH_API CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb);

/**
 * hasThreadLocalCallbacks returns whether there are callbacks registered
 * with addThreadLocalCallback
 */
TORCH_API bool hasThreadLocalCallbacks();

/**
 * clearThreadLocalCallbacks removes all thread local callbacks
 */
TORCH_API void clearThreadLocalCallbacks();
640 | |
/**
 * addGlobalCallback adds a global callback to run with RecordFunction,
 * returns handle to use with removeCallback.
 *
 * NOTE(review): the original comment here is truncated ("only during the
 * program initialization"); presumably this function is meant to be called
 * only during the program initialization - confirm its thread-safety against
 * the implementation.
 */
TORCH_API CallbackHandle addGlobalCallback(RecordFunctionCallback cb);

/**
 * removeCallback removes a callback given the handle returned by
 * addThreadLocalCallback or addGlobalCallback;
 *
 * NOTE(review): the original comment here is truncated ("no other code can
 * run simultaneously"); presumably no other code may run concurrently while
 * a callback is being removed - confirm against the implementation.
 */
TORCH_API void removeCallback(CallbackHandle handle);

/**
 * Prevent the given callback from executing. If handle is invalid,
 * does nothing.
 */
TORCH_API void disableCallback(CallbackHandle handle);

/**
 * Allow the given callback, previously disabled with disableCallback, to
 * execute again. If handle is invalid, does nothing.
 */
TORCH_API void reenableCallback(CallbackHandle handle);
667 | |
/**
 * hasGlobalCallbacks returns whether there are global callbacks
 * registered with addGlobalCallback
 */
TORCH_API bool hasGlobalCallbacks();

/**
 * clearGlobalCallbacks removes all global callbacks
 */
TORCH_API void clearGlobalCallbacks();

// for both thread local and global callbacks
TORCH_API bool hasCallbacks();
TORCH_API void clearCallbacks();
682 | |
/**
 * enableRecordFunction enables (or, when `enable` is false, disables)
 * RecordFunction thread locally
 */
TORCH_API void enableRecordFunction(bool enable = true);

/**
 * isRecordFunctionEnabled returns whether RecordFunction
 * is enabled thread locally
 */
TORCH_API bool isRecordFunctionEnabled();
693 | |
694 | class TORCH_API RecordFunctionGuard { |
695 | public: |
696 | explicit RecordFunctionGuard(bool is_enabled = true) |
697 | : prev_value_(isRecordFunctionEnabled()) { |
698 | enableRecordFunction(is_enabled); |
699 | } |
700 | |
701 | virtual ~RecordFunctionGuard() { |
702 | enableRecordFunction(prev_value_); |
703 | } |
704 | |
705 | private: |
706 | bool prev_value_ = false; |
707 | }; |
708 | |
709 | class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard { |
710 | public: |
711 | DisableRecordFunctionGuard() : RecordFunctionGuard(false) {} |
712 | ~DisableRecordFunctionGuard() override = default; |
713 | }; |
714 | |
715 | struct TORCH_API RecordFunctionTLS { |
716 | // Thread local vector of callbacks, holds pairs (callbacks, unique_id); |
717 | // must be sorted in increasing handles order |
718 | RecordFunctionCallbacks sorted_tls_callbacks_; |
719 | |
720 | bool tls_record_function_enabled_ = true; |
721 | }; |
722 | |
// Returns a reference to the RecordFunctionTLS state.
// NOTE(review): presumably the state of the calling thread - confirm against
// the implementation.
TORCH_API const RecordFunctionTLS& get_record_function_tls_();

// Replaces the RecordFunctionTLS state with `tls`.
TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls);

// Testing hook taking a 32-bit seed.
// NOTE(review): presumably seeds the random source used for callback
// sampling - confirm against the implementation.
TORCH_API void set_record_function_seed_for_testing(uint32_t seed);
728 | |
729 | } // namespace at |
730 | |