1#include <ATen/ThreadLocalState.h>
2#include <ATen/cpp_custom_type_hack.h>
3#include <ATen/record_function.h>
4#include <torch/csrc/autograd/record_function_ops.h>
5
6#include <torch/csrc/jit/runtime/operator.h>
7#include <torch/library.h>
8
9namespace caffe2 {
10// Required for cpp_custom_type_hack to work
11// NOLINTNEXTLINE(bugprone-exception-escape)
12CAFFE_KNOWN_TYPE(at::RecordFunction);
13} // namespace caffe2
14
15namespace torch {
16namespace autograd {
17namespace profiler {
18
19// Creates a new profiling scope using RecordFunction and invokes its starting
20// callbacks.
21void record_function_enter(
22 const std::string& name,
23 const c10::optional<std::string>& args,
24 at::RecordFunction& rec) {
25 if (rec.isActive()) {
26 if (rec.needsInputs() && args.has_value()) {
27 rec.before(
28 name, c10::ArrayRef<const c10::IValue>{c10::IValue{args.value()}});
29 } else {
30 rec.before(name);
31 }
32 }
33}
34
35// Legacy signature using cpp_custom_type_hack
36at::Tensor record_function_enter_legacy(
37 const std::string& name,
38 const c10::optional<std::string>& args) {
39 auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
40 record_function_enter(name, args, *rec);
41 return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
42}
43
44// New signature using custom_class
45c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
46 const std::string& name,
47 const c10::optional<std::string>& args) {
48 auto rec =
49 c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE);
50 record_function_enter(name, args, rec->record);
51 return rec;
52}
53
54at::RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) {
55 auto& rec = at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
56 return rec;
57}
58
59// Ends the profiling scope created with record_function_enter.
60void record_function_exit(at::RecordFunction& rec) {
61 rec.end();
62}
63
64// Legacy signature using cpp_custom_type_hack
65void record_function_exit_legacy(const at::Tensor& handle) {
66 // We don't actually need to do anything with handle just need to persist the
67 // lifetime until now.
68 auto& rec = getRecordFunctionFromTensor(handle);
69 record_function_exit(rec);
70}
71
72// New signature using custom_class
73void record_function_exit_new(
74 const c10::intrusive_ptr<PythonRecordFunction>& record) {
75 record_function_exit(record->record);
76}
77
78template <typename Func>
79c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
80 Func get_record,
81 const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
82 // Profiling callback that ends the associated record_function
83 // and returns the value of the passed in future.
84 std::function<c10::IValue(c10::ivalue::Future&)> futureProfilingFunc =
85 [get_record = std::move(get_record)](c10::ivalue::Future& fut) {
86 auto& rec = get_record();
87 rec.end();
88 // Note: this future is returned to the user to ensure that a call to
89 // wait() ensures that profiling callbacks have ran. To ensure that this
90 // is transparent, we must make this future propagate the value of the
91 // RPC future. Use value() here instead of constValue() to ensure we
92 // propagate errors.
93 return fut.value();
94 };
95 // Define a future that completes after the profiling callbacks are run.
96 auto profiledFut = fut->then(
97 at::wrapPropagateTLSState(std::move(futureProfilingFunc)),
98 fut->elementType());
99 return profiledFut;
100}
101
102// Legacy signature using cpp_custom_type_hack
103c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
104 const at::Tensor& handle,
105 const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
106 return _call_end_callbacks_on_fut(
107 [handle]() -> at::RecordFunction& {
108 TORCH_INTERNAL_ASSERT(
109 handle.defined(),
110 "Undefined RecordFunction handle. This can happen if the handle is "
111 "not correctly persisted and is destroyed before the future is "
112 "realized.");
113
114 return getRecordFunctionFromTensor(handle);
115 },
116 fut);
117}
118
119// New signature using custom_class
120c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
121 const c10::intrusive_ptr<PythonRecordFunction>& record,
122 const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
123 return _call_end_callbacks_on_fut(
124 [record]() -> at::RecordFunction& { return record->record; }, fut);
125}
126
127// Internal only, do not use directly, use Python's record_function()
128TORCH_LIBRARY_FRAGMENT(profiler, m) {
129 m.class_<PythonRecordFunction>("_RecordFunction");
130
131 m.def(
132 "_record_function_enter(str name, str? args=None) -> Tensor",
133 &record_function_enter_legacy);
134 m.def(
135 "_record_function_enter_new(str name, str? args=None) -> "
136 "__torch__.torch.classes.profiler._RecordFunction",
137 &record_function_enter_new);
138 m.def("_record_function_exit", &record_function_exit_legacy);
139 m.def("_record_function_exit._RecordFunction", &record_function_exit_new);
140
141 torch::jit::registerOperator(torch::jit::Operator(
142 "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
143 [](jit::Stack& stack) {
144 // Pop inputs, which should be a future and a tensor
145 auto fut = jit::pop(stack).toFuture();
146 auto tensor = jit::pop(stack).toTensor();
147 auto profiledFut = _call_end_callbacks_on_fut_legacy(tensor, fut);
148 // return future that completes when profiling callbacks have run.
149 jit::push(stack, std::move(profiledFut));
150 },
151 c10::AliasAnalysisKind::FROM_SCHEMA));
152 torch::jit::registerOperator(torch::jit::Operator(
153 "profiler::_call_end_callbacks_on_jit_fut._RecordFunction("
154 "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)",
155 [](c10::Stack& stack) {
156 // Pop inputs, which should be a future and a PythonRecordFunction
157 auto fut = torch::jit::pop(stack).toFuture();
158 auto tensor =
159 torch::jit::pop(stack).toCustomClass<PythonRecordFunction>();
160 auto profiledFut = _call_end_callbacks_on_fut_new(tensor, fut);
161 // return future that completes when profiling callbacks have run.
162 torch::jit::push(stack, std::move(profiledFut));
163 },
164 c10::AliasAnalysisKind::FROM_SCHEMA));
165}
166
167} // namespace profiler
168} // namespace autograd
169} // namespace torch
170