1 | #pragma once |
---|---|
2 | #include <ATen/record_function.h> |
3 | #include <c10/util/Optional.h> |
4 | #include <torch/custom_class.h> |
5 | |
6 | namespace torch { |
7 | namespace autograd { |
8 | namespace profiler { |
9 | |
10 | struct PythonRecordFunction : public torch::CustomClassHolder { |
11 | at::RecordFunction record; |
12 | |
13 | explicit PythonRecordFunction( |
14 | at::RecordScope scope = at::RecordScope::FUNCTION) |
15 | : record(scope) {} |
16 | }; |
17 | |
18 | // Creates a new profiling scope using RecordFunction and invokes its starting |
19 | // callbacks. |
20 | TORCH_API c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new( |
21 | const std::string& name, |
22 | const c10::optional<std::string>& args = c10::nullopt); |
23 | |
24 | // Schedules RecordFunction's end callbacks to be run on completion of a future. |
25 | TORCH_API c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new( |
26 | const c10::intrusive_ptr<PythonRecordFunction>& record, |
27 | const c10::intrusive_ptr<c10::ivalue::Future>& fut); |
28 | |
29 | } // namespace profiler |
30 | } // namespace autograd |
31 | } // namespace torch |
32 |