1 | #include <utility> |
2 | |
3 | #include <torch/csrc/utils/pybind.h> |
4 | #include <torch/csrc/utils/python_arg_parser.h> |
5 | #include <torch/csrc/utils/python_numbers.h> |
6 | #include <torch/csrc/utils/python_strings.h> |
7 | |
8 | #include <pybind11/chrono.h> |
9 | #include <pybind11/functional.h> |
10 | #include <pybind11/operators.h> |
11 | #include <pybind11/stl.h> |
12 | |
13 | #include <torch/csrc/monitor/counters.h> |
14 | #include <torch/csrc/monitor/events.h> |
15 | |
16 | namespace pybind11 { |
17 | namespace detail { |
18 | template <> |
19 | struct type_caster<torch::monitor::data_value_t> { |
20 | public: |
21 | PYBIND11_TYPE_CASTER(torch::monitor::data_value_t, _("data_value_t" )); |
22 | |
23 | // Python -> C++ |
24 | bool load(handle src, bool) { |
25 | PyObject* source = src.ptr(); |
26 | if (THPUtils_checkLong(source)) { |
27 | this->value = THPUtils_unpackLong(source); |
28 | } else if (THPUtils_checkDouble(source)) { |
29 | this->value = THPUtils_unpackDouble(source); |
30 | } else if (THPUtils_checkString(source)) { |
31 | this->value = THPUtils_unpackString(source); |
32 | } else if (PyBool_Check(source)) { |
33 | this->value = THPUtils_unpackBool(source); |
34 | } else { |
35 | return false; |
36 | } |
37 | return !PyErr_Occurred(); |
38 | } |
39 | |
40 | // C++ -> Python |
41 | static handle cast( |
42 | torch::monitor::data_value_t src, |
43 | return_value_policy /* policy */, |
44 | handle /* parent */) { |
45 | if (c10::holds_alternative<double>(src)) { |
46 | return PyFloat_FromDouble(c10::get<double>(src)); |
47 | } else if (c10::holds_alternative<int64_t>(src)) { |
48 | return THPUtils_packInt64(c10::get<int64_t>(src)); |
49 | } else if (c10::holds_alternative<bool>(src)) { |
50 | if (c10::get<bool>(src)) { |
51 | Py_RETURN_TRUE; |
52 | } else { |
53 | Py_RETURN_FALSE; |
54 | } |
55 | } else if (c10::holds_alternative<std::string>(src)) { |
56 | std::string str = c10::get<std::string>(src); |
57 | return THPUtils_packString(str); |
58 | } |
59 | throw std::runtime_error("unknown data_value_t type" ); |
60 | } |
61 | }; |
62 | } // namespace detail |
63 | } // namespace pybind11 |
64 | |
65 | namespace torch { |
66 | namespace monitor { |
67 | |
68 | namespace { |
69 | class PythonEventHandler : public EventHandler { |
70 | public: |
71 | explicit PythonEventHandler(std::function<void(const Event&)> handler) |
72 | : handler_(std::move(handler)) {} |
73 | |
74 | void handle(const Event& e) override { |
75 | handler_(e); |
76 | } |
77 | |
78 | private: |
79 | std::function<void(const Event&)> handler_; |
80 | }; |
81 | } // namespace |
82 | |
83 | void initMonitorBindings(PyObject* module) { |
84 | auto rootModule = py::handle(module).cast<py::module>(); |
85 | |
86 | auto m = rootModule.def_submodule("_monitor" ); |
87 | |
88 | py::enum_<Aggregation>( |
89 | m, |
90 | "Aggregation" , |
91 | R"DOC( |
92 | These are types of aggregations that can be used to accumulate stats. |
93 | )DOC" ) |
94 | .value( |
95 | "VALUE" , |
96 | Aggregation::NONE, |
97 | R"DOC( |
98 | VALUE returns the last value to be added. |
99 | )DOC" ) |
100 | .value( |
101 | "MEAN" , |
102 | Aggregation::MEAN, |
103 | R"DOC( |
104 | MEAN computes the arithmetic mean of all the added values. |
105 | )DOC" ) |
106 | .value( |
107 | "COUNT" , |
108 | Aggregation::COUNT, |
109 | R"DOC( |
110 | COUNT returns the total number of added values. |
111 | )DOC" ) |
112 | .value( |
113 | "SUM" , |
114 | Aggregation::SUM, |
115 | R"DOC( |
116 | SUM returns the sum of the added values. |
117 | )DOC" ) |
118 | .value( |
119 | "MAX" , |
120 | Aggregation::MAX, |
121 | R"DOC( |
122 | MAX returns the max of the added values. |
123 | )DOC" ) |
124 | .value( |
125 | "MIN" , |
126 | Aggregation::MIN, |
127 | R"DOC( |
128 | MIN returns the min of the added values. |
129 | )DOC" ) |
130 | .export_values(); |
131 | |
132 | py::class_<Stat<double>>( |
133 | m, |
134 | "Stat" , |
135 | R"DOC( |
136 | Stat is used to compute summary statistics in a performant way over |
137 | fixed intervals. Stat logs the statistics as an Event once every |
138 | ``window_size`` duration. When the window closes the stats are logged |
139 | via the event handlers as a ``torch.monitor.Stat`` event. |
140 | |
141 | ``window_size`` should be set to something relatively high to avoid a |
142 | huge number of events being logged. Ex: 60s. Stat uses millisecond |
143 | precision. |
144 | |
145 | If ``max_samples`` is set, the stat will cap the number of samples per |
146 | window by discarding `add` calls once ``max_samples`` adds have |
147 | occurred. If it's not set, all ``add`` calls during the window will be |
148 | included. This is an optional field to make aggregations more directly |
149 | comparable across windows when the number of samples might vary. |
150 | |
151 | When the Stat is destructed it will log any remaining data even if the |
152 | window hasn't elapsed. |
153 | )DOC" ) |
154 | .def( |
155 | py::init< |
156 | std::string, |
157 | std::vector<Aggregation>, |
158 | std::chrono::milliseconds, |
159 | int64_t>(), |
160 | py::arg("name" ), |
161 | py::arg("aggregations" ), |
162 | py::arg("window_size" ), |
163 | py::arg("max_samples" ) = std::numeric_limits<int64_t>::max(), |
164 | R"DOC( |
165 | Constructs the ``Stat``. |
166 | )DOC" ) |
167 | .def( |
168 | "add" , |
169 | &Stat<double>::add, |
170 | py::arg("v" ), |
171 | R"DOC( |
172 | Adds a value to the stat to be aggregated according to the |
173 | configured stat type and aggregations. |
174 | )DOC" ) |
175 | .def( |
176 | "get" , |
177 | &Stat<double>::get, |
178 | R"DOC( |
179 | Returns the current value of the stat, primarily for testing |
180 | purposes. If the stat has logged and no additional values have been |
181 | added this will be zero. |
182 | )DOC" ) |
183 | .def_property_readonly( |
184 | "name" , |
185 | &Stat<double>::name, |
186 | R"DOC( |
187 | The name of the stat that was set during creation. |
188 | )DOC" ) |
189 | .def_property_readonly( |
190 | "count" , |
191 | &Stat<double>::count, |
192 | R"DOC( |
193 | Number of data points that have currently been collected. Resets |
194 | once the event has been logged. |
195 | )DOC" ); |
196 | |
197 | py::class_<Event>( |
198 | m, |
199 | "Event" , |
200 | R"DOC( |
201 | Event represents a specific typed event to be logged. This can represent |
202 | high-level data points such as loss or accuracy per epoch or more |
203 | low-level aggregations such as through the Stats provided through this |
204 | library. |
205 | |
206 | All Events of the same type should have the same name so downstream |
207 | handlers can correctly process them. |
208 | )DOC" ) |
209 | .def( |
210 | py::init([](const std::string& name, |
211 | std::chrono::system_clock::time_point timestamp, |
212 | std::unordered_map<std::string, data_value_t> data) { |
213 | Event e; |
214 | e.name = name; |
215 | e.timestamp = timestamp; |
216 | e.data = data; |
217 | return e; |
218 | }), |
219 | py::arg("name" ), |
220 | py::arg("timestamp" ), |
221 | py::arg("data" ), |
222 | R"DOC( |
223 | Constructs the ``Event``. |
224 | )DOC" ) |
225 | .def_readwrite( |
226 | "name" , |
227 | &Event::name, |
228 | R"DOC( |
229 | The name of the ``Event``. |
230 | )DOC" ) |
231 | .def_readwrite( |
232 | "timestamp" , |
233 | &Event::timestamp, |
234 | R"DOC( |
235 | The timestamp when the ``Event`` happened. |
236 | )DOC" ) |
237 | .def_readwrite( |
238 | "data" , |
239 | &Event::data, |
240 | R"DOC( |
241 | The structured data contained within the ``Event``. |
242 | )DOC" ); |
243 | |
244 | m.def( |
245 | "log_event" , |
246 | &logEvent, |
247 | py::arg("event" ), |
248 | R"DOC( |
249 | log_event logs the specified event to all of the registered event |
250 | handlers. It's up to the event handlers to log the event out to the |
251 | corresponding event sink. |
252 | |
253 | If there are no event handlers registered this method is a no-op. |
254 | )DOC" ); |
255 | |
256 | py::class_<data_value_t> dataClass( |
257 | m, |
258 | "data_value_t" , |
259 | R"DOC( |
260 | data_value_t is one of ``str``, ``float``, ``int``, ``bool``. |
261 | )DOC" ); |
262 | |
263 | py::implicitly_convertible<std::string, data_value_t>(); |
264 | py::implicitly_convertible<double, data_value_t>(); |
265 | py::implicitly_convertible<int64_t, data_value_t>(); |
266 | py::implicitly_convertible<bool, data_value_t>(); |
267 | |
268 | py::class_<PythonEventHandler, std::shared_ptr<PythonEventHandler>> |
269 | eventHandlerClass(m, "EventHandlerHandle" , R"DOC( |
270 | EventHandlerHandle is a wrapper type returned by |
271 | ``register_event_handler`` used to unregister the handler via |
272 | ``unregister_event_handler``. This cannot be directly initialized. |
273 | )DOC" ); |
274 | m.def( |
275 | "register_event_handler" , |
276 | [](std::function<void(const Event&)> f) { |
277 | auto handler = std::make_shared<PythonEventHandler>(f); |
278 | registerEventHandler(handler); |
279 | return handler; |
280 | }, |
281 | py::arg("callback" ), |
282 | R"DOC( |
283 | register_event_handler registers a callback to be called whenever an |
284 | event is logged via ``log_event``. These handlers should avoid blocking |
285 | the main thread since that may interfere with training as they run |
286 | during the ``log_event`` call. |
287 | )DOC" ); |
288 | m.def( |
289 | "unregister_event_handler" , |
290 | [](std::shared_ptr<PythonEventHandler> handler) { |
291 | unregisterEventHandler(handler); |
292 | }, |
293 | py::arg("handler" ), |
294 | R"DOC( |
295 | unregister_event_handler unregisters the ``EventHandlerHandle`` returned |
296 | after calling ``register_event_handler``. After this returns the event |
297 | handler will no longer receive events. |
298 | )DOC" ); |
299 | } |
300 | |
301 | } // namespace monitor |
302 | } // namespace torch |
303 | |