1 | #include <ATen/ThreadLocalState.h> |
2 | #include <ATen/cpp_custom_type_hack.h> |
3 | #include <ATen/record_function.h> |
4 | #include <torch/csrc/autograd/record_function_ops.h> |
5 | |
6 | #include <torch/csrc/jit/runtime/operator.h> |
7 | #include <torch/library.h> |
8 | |
9 | namespace caffe2 { |
10 | // Required for cpp_custom_type_hack to work |
11 | // NOLINTNEXTLINE(bugprone-exception-escape) |
12 | CAFFE_KNOWN_TYPE(at::RecordFunction); |
13 | } // namespace caffe2 |
14 | |
15 | namespace torch { |
16 | namespace autograd { |
17 | namespace profiler { |
18 | |
19 | // Creates a new profiling scope using RecordFunction and invokes its starting |
20 | // callbacks. |
21 | void record_function_enter( |
22 | const std::string& name, |
23 | const c10::optional<std::string>& args, |
24 | at::RecordFunction& rec) { |
25 | if (rec.isActive()) { |
26 | if (rec.needsInputs() && args.has_value()) { |
27 | rec.before( |
28 | name, c10::ArrayRef<const c10::IValue>{c10::IValue{args.value()}}); |
29 | } else { |
30 | rec.before(name); |
31 | } |
32 | } |
33 | } |
34 | |
35 | // Legacy signature using cpp_custom_type_hack |
36 | at::Tensor record_function_enter_legacy( |
37 | const std::string& name, |
38 | const c10::optional<std::string>& args) { |
39 | auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE); |
40 | record_function_enter(name, args, *rec); |
41 | return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions()); |
42 | } |
43 | |
44 | // New signature using custom_class |
45 | c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new( |
46 | const std::string& name, |
47 | const c10::optional<std::string>& args) { |
48 | auto rec = |
49 | c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE); |
50 | record_function_enter(name, args, rec->record); |
51 | return rec; |
52 | } |
53 | |
54 | at::RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) { |
55 | auto& rec = at::cpp_custom_type_hack::cast<at::RecordFunction>(handle); |
56 | return rec; |
57 | } |
58 | |
59 | // Ends the profiling scope created with record_function_enter. |
60 | void record_function_exit(at::RecordFunction& rec) { |
61 | rec.end(); |
62 | } |
63 | |
64 | // Legacy signature using cpp_custom_type_hack |
65 | void record_function_exit_legacy(const at::Tensor& handle) { |
66 | // We don't actually need to do anything with handle just need to persist the |
67 | // lifetime until now. |
68 | auto& rec = getRecordFunctionFromTensor(handle); |
69 | record_function_exit(rec); |
70 | } |
71 | |
72 | // New signature using custom_class |
73 | void record_function_exit_new( |
74 | const c10::intrusive_ptr<PythonRecordFunction>& record) { |
75 | record_function_exit(record->record); |
76 | } |
77 | |
78 | template <typename Func> |
79 | c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut( |
80 | Func get_record, |
81 | const c10::intrusive_ptr<c10::ivalue::Future>& fut) { |
82 | // Profiling callback that ends the associated record_function |
83 | // and returns the value of the passed in future. |
84 | std::function<c10::IValue(c10::ivalue::Future&)> futureProfilingFunc = |
85 | [get_record = std::move(get_record)](c10::ivalue::Future& fut) { |
86 | auto& rec = get_record(); |
87 | rec.end(); |
88 | // Note: this future is returned to the user to ensure that a call to |
89 | // wait() ensures that profiling callbacks have ran. To ensure that this |
90 | // is transparent, we must make this future propagate the value of the |
91 | // RPC future. Use value() here instead of constValue() to ensure we |
92 | // propagate errors. |
93 | return fut.value(); |
94 | }; |
95 | // Define a future that completes after the profiling callbacks are run. |
96 | auto profiledFut = fut->then( |
97 | at::wrapPropagateTLSState(std::move(futureProfilingFunc)), |
98 | fut->elementType()); |
99 | return profiledFut; |
100 | } |
101 | |
102 | // Legacy signature using cpp_custom_type_hack |
103 | c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy( |
104 | const at::Tensor& handle, |
105 | const c10::intrusive_ptr<c10::ivalue::Future>& fut) { |
106 | return _call_end_callbacks_on_fut( |
107 | [handle]() -> at::RecordFunction& { |
108 | TORCH_INTERNAL_ASSERT( |
109 | handle.defined(), |
110 | "Undefined RecordFunction handle. This can happen if the handle is " |
111 | "not correctly persisted and is destroyed before the future is " |
112 | "realized." ); |
113 | |
114 | return getRecordFunctionFromTensor(handle); |
115 | }, |
116 | fut); |
117 | } |
118 | |
119 | // New signature using custom_class |
120 | c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new( |
121 | const c10::intrusive_ptr<PythonRecordFunction>& record, |
122 | const c10::intrusive_ptr<c10::ivalue::Future>& fut) { |
123 | return _call_end_callbacks_on_fut( |
124 | [record]() -> at::RecordFunction& { return record->record; }, fut); |
125 | } |
126 | |
127 | // Internal only, do not use directly, use Python's record_function() |
128 | TORCH_LIBRARY_FRAGMENT(profiler, m) { |
129 | m.class_<PythonRecordFunction>("_RecordFunction" ); |
130 | |
131 | m.def( |
132 | "_record_function_enter(str name, str? args=None) -> Tensor" , |
133 | &record_function_enter_legacy); |
134 | m.def( |
135 | "_record_function_enter_new(str name, str? args=None) -> " |
136 | "__torch__.torch.classes.profiler._RecordFunction" , |
137 | &record_function_enter_new); |
138 | m.def("_record_function_exit" , &record_function_exit_legacy); |
139 | m.def("_record_function_exit._RecordFunction" , &record_function_exit_new); |
140 | |
141 | torch::jit::registerOperator(torch::jit::Operator( |
142 | "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)" , |
143 | [](jit::Stack& stack) { |
144 | // Pop inputs, which should be a future and a tensor |
145 | auto fut = jit::pop(stack).toFuture(); |
146 | auto tensor = jit::pop(stack).toTensor(); |
147 | auto profiledFut = _call_end_callbacks_on_fut_legacy(tensor, fut); |
148 | // return future that completes when profiling callbacks have run. |
149 | jit::push(stack, std::move(profiledFut)); |
150 | }, |
151 | c10::AliasAnalysisKind::FROM_SCHEMA)); |
152 | torch::jit::registerOperator(torch::jit::Operator( |
153 | "profiler::_call_end_callbacks_on_jit_fut._RecordFunction(" |
154 | "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)" , |
155 | [](c10::Stack& stack) { |
156 | // Pop inputs, which should be a future and a PythonRecordFunction |
157 | auto fut = torch::jit::pop(stack).toFuture(); |
158 | auto tensor = |
159 | torch::jit::pop(stack).toCustomClass<PythonRecordFunction>(); |
160 | auto profiledFut = _call_end_callbacks_on_fut_new(tensor, fut); |
161 | // return future that completes when profiling callbacks have run. |
162 | torch::jit::push(stack, std::move(profiledFut)); |
163 | }, |
164 | c10::AliasAnalysisKind::FROM_SCHEMA)); |
165 | } |
166 | |
167 | } // namespace profiler |
168 | } // namespace autograd |
169 | } // namespace torch |
170 | |