#ifdef _WIN32
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>

#include <processthreadsapi.h>
#else
#include <unistd.h>
#endif // _WIN32

#include <fmt/format.h>
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <map>
#include <mutex>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <vector>

#include <ATen/core/TensorBody.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/stack.h>
#include <ATen/record_function.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/standalone/execution_graph_observer.h>
#include <torch/csrc/profiler/util.h>

using namespace at;

namespace torch {
namespace profiler {
namespace impl {

//******************************************************************************
// JSON output utility functions. To be merged with PyTorch profiler.
//******************************************************************************
template <typename T>
inline std::string vectorToString(const std::vector<T>& v) {
  return fmt::format("[{}]", fmt::join(v, ","));
}

constexpr size_t maxNumElements = 4096;

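// Returns a JSON-friendly string describing the type of an IValue: tensors
// include their element dtype, tuples and lists are expanded recursively
// (lists are truncated once maxArrayLen entries have been emitted). When
// baseType is true, the result is wrapped in quotes for direct embedding in
// the JSON output.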
inline std::string getValueType(
    const c10::IValue& val,
    const bool baseType = true,
    const size_t maxArrayLen = maxNumElements) {
  std::string type = val.tagKind();

  if (val.isTensor()) {
    // Add tensor element data type.
    type += fmt::format("({})", std::string(val.toTensor().dtype().name()));
  } else if (val.isTuple()) {
    const auto& val_container = val.toTupleRef().elements();
    std::vector<std::string> str_array;
    for (const auto& t : val_container) {
      str_array.emplace_back(getValueType(t, false));
    }
    type += vectorToString(str_array);
  } else if (val.isList()) {
    const auto& val_list = val.toList();
    std::vector<std::string> str_array;
    str_array.reserve(val_list.size());
    for (const auto j : c10::irange(val_list.size())) {
      str_array.push_back(getValueType(val_list.get(j), false));
      if (j >= maxArrayLen) {
        LOG(WARNING) << "list size=" << val_list.size()
                     << " exceeded maxArrayLen=" << maxArrayLen;
        break;
      }
    }
    type += vectorToString(str_array);
  }
  return baseType ? fmt::format("\"{}\"", type) : type;
}

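// Returns the shape of an IValue as a JSON array string: tensor sizes for
// tensors, recursively expanded containers for tuples and lists (truncated at
// maxArrayLen), and "[]" for everything else.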
inline std::string getValueShape(
    const c10::IValue& val,
    const size_t maxArrayLen = maxNumElements) {
  if (val.isTensor()) {
    auto& tensor = val.toTensor();
    if (tensor.defined()) {
      return vectorToString(tensor.sizes().vec());
    }
  } else if (val.isTuple()) {
    const auto& val_container = val.toTupleRef().elements();
    std::vector<std::string> str_array;
    for (const auto& t : val_container) {
      str_array.push_back(getValueShape(t));
    }
    return vectorToString(str_array);
  } else if (val.isList()) {
    const auto& val_list = val.toList();
    std::vector<std::string> str_array;
    str_array.reserve(val_list.size());
    for (const auto j : c10::irange(val_list.size())) {
      str_array.push_back(getValueShape(val_list.get(j)));
      if (j >= maxArrayLen) {
        LOG(WARNING) << "list size=" << val_list.size()
                     << " exceeded maxArrayLen=" << maxArrayLen;
        break;
      }
    }
    return vectorToString(str_array);
  }
  return "[]";
}

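// Converts a scalar IValue to its JSON representation. Strings, devices, and
// non-finite doubles are quoted; strings longer than maxNumElements are
// truncated; unsupported tags fall back to a quoted "<tagKind>" placeholder.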
inline std::string getScalarValue(const c10::IValue& val) {
  if (val.isDouble()) {
    double d_val = val.toDouble();
    if (std::isinf(d_val) || std::isnan(d_val)) {
      return fmt::format("\"{}\"", std::to_string(d_val));
    } else {
      return std::to_string(d_val);
    }
  } else if (val.isInt()) {
    return std::to_string(val.toInt());
  } else if (val.isBool()) {
    return val.toBool() ? "true" : "false";
  } else if (val.isString()) {
    const std::string& str_val = val.toStringRef();
    if (str_val.size() > maxNumElements) {
      LOG(WARNING) << "string size=" << str_val.size()
                   << " exceeded maxNumElements=" << maxNumElements;
      return fmt::format("\"{}\"", str_val.substr(0, maxNumElements));
    }

    return fmt::format("\"{}\"", str_val);
  } else if (val.isDevice()) {
    return fmt::format("\"{}\"", val.toDevice().str());
  }
  return fmt::format("\"<{}>\"", val.tagKind());
}

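// Returns the current process id in a platform-independent way.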
inline int32_t processId() {
#ifndef _WIN32
  return static_cast<int32_t>(getpid());
#else
  return static_cast<int32_t>(GetCurrentProcessId());
#endif
}

//******************************************************************************
// Main ExecutionGraphObserver implementation.
//******************************************************************************

// ExecutionGraphObserver contains all of the observer's state. Some of it is
// shared between the enter and exit RecordFunction callbacks, and some data,
// such as `op_stack`, may be accessed from different threads, so we need to be
// careful about data races. A global mutex `g_mutex` is used to avoid these
// races, at the cost of performance when many threads are involved. This could
// be optimized further with thread-local state, fine-grained locking, or
// thread-safe containers.
struct TORCH_API ExecutionGraphObserver {
  using ID = size_t;

  // Mapping of each thread to its own operator stack.
  std::map<size_t, std::stack<ID>> op_stack{};
  // Uses the underlying object pointer (e.g. TensorImpl or storage data) as
  // the key and maps it to its unique id.
  std::map<void*, ID> object_id{};
  // Observer run state.
  enum class RunState { uninitialized, disabled, enabled };

  // Mutex for multithreaded access to the shared containers.
  std::mutex g_mutex{};
  // Stream to write output JSON.
  std::ofstream out{};

  // Full path to the output file.
  std::string file_name{};

  // RecordFunction callback handle for this observer.
  CallbackHandle cb_handle{INVALID_CALLBACK_HANDLE};

  // Process ID.
  int32_t pid{-1};
  std::string record_time{};

  ExecutionGraphObserver() = default;

  // Returns a new unique ID.
  ID getNewID() {
    return id_++;
  }

  RunState getState() const {
    return state_;
  }

  void setState(RunState newState) {
    if (state_ == RunState::uninitialized ||
        callbackShouldBeEnabled(state_) != callbackShouldBeEnabled(newState)) {
      if (callbackShouldBeEnabled(newState)) {
        reenableCallback(cb_handle);
      } else {
        disableCallback(cb_handle);
      }
    }
    state_ = newState;
  }

 private:
  static bool callbackShouldBeEnabled(RunState run_state) {
    return run_state == ExecutionGraphObserver::RunState::enabled;
  }

  // Must use accessors to change this so that we can keep the
  // RecordFunction callback in sync with the state.
  RunState state_{RunState::uninitialized};

  // Every tensor and operator is assigned a unique id. The id is incremented
  // for each new tensor or operator node.
  // 0 -> uninitialized
  // 1 -> root ID
  // 2 ... -> regular node ID
  std::atomic<ID> id_{2};
};

// A singleton manager is used here so the observer object can be initialized
// and deleted on demand.
using ObserverManager = GlobalStateManager<ExecutionGraphObserver>;

// Uninitialized node has id = 0.
const ExecutionGraphObserver::ID uninitialized_id{0};
// Root node has id = 1.
const ExecutionGraphObserver::ID root_id{1};

struct FunctionCallContext : public ObserverContext {
  std::string name;
  ExecutionGraphObserver::ID op_id{uninitialized_id};
  ExecutionGraphObserver::ID parent_id{uninitialized_id};
  ExecutionGraphObserver::ID fw_parent_id{uninitialized_id};
  std::vector<std::string> input_types;
  std::vector<std::string> input_shapes;
  std::vector<std::string> input_values;
};

// Opens the JSON file to write the execution graph.
std::ofstream openOutputFile(const std::string& name) {
  std::ofstream stream;
  stream.open(name, std::ofstream::out | std::ofstream::trunc);
  if (!stream) {
    LOG(ERROR) << "Failed to open '" << name << "'";
  } else {
    VLOG(1) << "Writing PyTorch execution graph to: " << name;
  }
  return stream;
}

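// Writes a single execution graph node as a JSON object to the output stream.
// The string arguments (inputs, shapes, types, outputs, and the operator
// schema) are expected to already be valid JSON fragments.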
void writeJsonNode(
    std::ofstream& out,
    const std::string& name,
    const uint64_t id,
    const uint64_t rf_id,
    const uint64_t parent,
    const uint64_t fw_parent,
    const int64_t seq_id,
    const uint64_t scope,
    const uint64_t tid,
    const uint64_t fw_tid,
    const std::string& inputs = "[]",
    const std::string& input_shapes = "[]",
    const std::string& input_types = "[]",
    const std::string& outputs = "[]",
    const std::string& output_shapes = "[]",
    const std::string& output_types = "[]",
    const std::string& operator_schema = "") {
  out << fmt::format(
      R"JSON(
    {{
      "name": "{}", "id": {}, "rf_id": {}, "parent": {}, "fw_parent": {}, "seq_id": {}, "scope": {}, "tid": {}, "fw_tid": {}, "op_schema": "{}",
      "inputs": {}, "input_shapes": {}, "input_types": {},
      "outputs": {}, "output_shapes": {}, "output_types": {}
    }})JSON",
      name,
      id,
      rf_id,
      parent,
      fw_parent,
      seq_id,
      scope,
      tid,
      fw_tid,
      operator_schema,
      inputs,
      input_shapes,
      input_types,
      outputs,
      output_shapes,
      output_types);
}

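// Formats a time_t as a human-readable local time string.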
inline std::string timeString(const std::time_t timepoint) {
  std::ostringstream oss;
  oss << std::put_time(std::localtime(&timepoint), "%Y-%m-%d %X");
  return oss.str();
}

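// Opens the output file and writes the JSON header: schema version, process
// id, wall-clock record time, and a steady_clock start timestamp. Leaves the
// "nodes" array open so node records can be appended. Returns false if the
// file could not be opened.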
bool initExecutionGraphStart(ExecutionGraphObserver& ob) {
  ob.out = openOutputFile(ob.file_name);
  // If the output stream failed to open, finish the observer here.
  if (!ob.out) {
    LOG(WARNING) << "Failed to open output file: " << ob.file_name;
    return false;
  }

  // Wall-clock time at which the first op collection starts.
  const auto current_time = std::chrono::system_clock::now();
  ob.record_time =
      timeString(std::chrono::system_clock::to_time_t(current_time));
  // Start timestamp using steady_clock for measurement.
  const auto timestamp =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();

  ob.out << fmt::format(
      R"JSON({{
  "schema": "1.0.1", "pid": {}, "time": "{}", "start_ts": {},
  "nodes": [)JSON",
      ob.pid,
      ob.record_time,
      timestamp);
  return true;
}

// Writes out the execution graph to the output file.
void finalizeExecutionGraphOutput(ExecutionGraphObserver& ob) {
  writeJsonNode(
      ob.out,
      "[pytorch|profiler|execution_graph|process]",
      root_id,
      0, // rf_id
      root_id, // parent is self
      0, // fw_parent
      -1, // seq_id
      static_cast<std::underlying_type_t<RecordScope>>(RecordScope::USER_SCOPE),
      0, // tid
      0); // fw_tid

  // Finish timestamp using steady_clock for measurement.
  const auto timestamp =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();
  ob.out << fmt::format(
      R"JSON(
  ],
  "finish_ts": {}
}})JSON",
      timestamp);

  ob.out.close();
  VLOG(1) << "PyTorch execution graph is written to file: " << ob.file_name;
}

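// Returns the unique id for an object pointer (e.g. a TensorImpl or storage
// data pointer), assigning a new id the first time the pointer is seen.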
inline ExecutionGraphObserver::ID getObjectID(
    ExecutionGraphObserver& ob,
    void* t) {
  auto iter = ob.object_id.find(t);
  if (iter == ob.object_id.end()) {
    ExecutionGraphObserver::ID object_id = ob.getNewID();
    ob.object_id[t] = object_id;
    return object_id;
  }

  return iter->second;
}

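// Serializes an IValue for the JSON output. A tensor becomes
// [tensor_id, storage_id, offset, numel, itemsize, "device"]; tuples and
// lists are expanded recursively (lists truncated at maxArrayLen); all other
// values fall back to getScalarValue().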
inline std::string convertIValue(
    ExecutionGraphObserver& ob,
    const c10::IValue& val,
    const size_t maxArrayLen = maxNumElements) {
  if (val.isTensor()) {
    const auto t = val.toTensor().unsafeGetTensorImpl();
    ExecutionGraphObserver::ID tensor_id = getObjectID(ob, t);
    ExecutionGraphObserver::ID storage_id = 0;
    size_t offset = 0;
    size_t numel = 0;
    size_t itemsize = 0;
    std::string device_str = "";
    if (t->has_storage()) {
      auto& t_storage = t->storage();
      storage_id = getObjectID(ob, t_storage.data());
      offset = t->storage_offset();
      numel = t->numel();
      itemsize = t->itemsize();
      device_str = t->device().str();
    }
    return fmt::format(
        "[{},{},{},{},{},\"{}\"]",
        tensor_id,
        storage_id,
        offset,
        numel,
        itemsize,
        device_str);
  } else if (val.isTuple()) {
    std::vector<std::string> str_array;
    const auto& val_tuple = val.toTupleRef().elements();
    for (const auto j : c10::irange(val_tuple.size())) {
      str_array.push_back(convertIValue(ob, val_tuple[j]));
    }
    return vectorToString(str_array);
  } else if (val.isList()) {
    const auto& val_list = val.toList();
    std::vector<std::string> str_array;
    str_array.reserve(val_list.size());
    for (const auto j : c10::irange(val_list.size())) {
      str_array.push_back(convertIValue(ob, val_list.get(j)));
      if (j >= maxArrayLen) {
        LOG(WARNING) << "list size=" << val_list.size()
                     << " exceeded maxArrayLen=" << maxArrayLen;
        break;
      }
    }
    return vectorToString(str_array);
  } else {
    return getScalarValue(val);
  }
}

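// Appends the serialized value, type, and shape of an IValue to the three
// output vectors.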
inline void appendValueInfo(
    ExecutionGraphObserver& ob,
    const c10::IValue& val,
    std::vector<std::string>& values,
    std::vector<std::string>& types,
    std::vector<std::string>& shapes) {
  values.push_back(convertIValue(ob, val));
  types.push_back(getValueType(val));
  shapes.push_back(getValueShape(val));
}

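// Enter-callback body: under g_mutex, lazily emits a per-thread root node,
// records the op name, input values/types/shapes, and parent ids, and pushes
// a fresh op id onto the calling thread's stack.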
void recordOperatorStart(
    ExecutionGraphObserver& ob,
    FunctionCallContext& fc,
    const RecordFunction& fn) {
  auto tid = fn.threadId();

  try {
    const std::lock_guard<std::mutex> lock(ob.g_mutex);

    // If the current thread's stack is empty, push the root node to the
    // stack first.
    if (ob.op_stack[tid].empty()) {
      auto thread_node_id = ob.getNewID();
      ob.op_stack[tid].push(thread_node_id);
      writeJsonNode(
          ob.out,
          "[pytorch|profiler|execution_graph|thread]",
          thread_node_id,
          0, // rf_id
          root_id,
          0, // fw_parent
          -1, // seq_id
          static_cast<std::underlying_type_t<RecordScope>>(
              RecordScope::USER_SCOPE),
          tid,
          0); // fw_tid
      ob.out << ",";
    }
    fc.name = fn.name();
    auto num_inputs = fn.num_inputs();
    const auto inputs = fn.inputs();

    VLOG(2) << "inputs: " << num_inputs << " " << inputs.size() << std::endl;
    // There are two cases: for an unboxed kernel, num_inputs == inputs.size();
    // for a boxed kernel using the stack, there can be more elements on the
    // stack left over from previous ops.
    // TORCH_INTERNAL_ASSERT(num_inputs <= inputs.size());
    if (num_inputs > inputs.size()) {
      LOG(WARNING) << "RecordFunction " << fc.name
                   << " expected num_inputs=" << num_inputs
                   << " > inputs.size()=" << inputs.size();
      return;
    }
    // Need to account for stack mode, where the inputs are at the end of the
    // stack.
    size_t input_start = inputs.size() - num_inputs;

    for (const auto i : c10::irange(input_start, inputs.size())) {
      appendValueInfo(
          ob, inputs[i], fc.input_values, fc.input_types, fc.input_shapes);
    }
    fc.parent_id = ob.op_stack[tid].top();
    // Get the parent id from the forward-thread stack; this can differ for
    // autograd ops, which may execute on a different thread than the original
    // thread (which should have the parent op on its stack).
    auto fw_tid = fn.forwardThreadId();
    if (fw_tid != 0) {
      fc.fw_parent_id = ob.op_stack[fw_tid].top();
    }
    // Assign the op its own unique id; output object ids recorded at exit
    // will be greater than op_id.
    fc.op_id = ob.getNewID();
    ob.op_stack[tid].push(fc.op_id);

  } catch (const std::exception& e) {
    LOG(WARNING) << "Exception in execution graph observer: " << e.what();
  }
}

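// RecordFunction enter callback. When the observer is enabled, captures the
// operator's start information into a FunctionCallContext that is handed back
// to onFunctionExit.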
std::unique_ptr<ObserverContext> onFunctionEnter(const RecordFunction& fn) {
  using RunState = ExecutionGraphObserver::RunState;
  auto ob = ObserverManager::get();
  if (ob != nullptr && ob->getState() == RunState::enabled) {
    // record op
    auto fc_ptr = std::make_unique<FunctionCallContext>();
    recordOperatorStart(*ob, *fc_ptr.get(), fn);
    return fc_ptr;
  }
  return nullptr;
}

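// Escapes a string for safe embedding in a JSON document (quotes,
// backslashes, and control characters).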
inline std::string json_str_escape(const std::string& str) {
  std::ostringstream ostream;
  for (char ch : str) {
    if (ch == '"') {
      ostream << "\\\"";
    } else if (ch == '\\') {
      ostream << "\\\\";
    } else if (ch == '\b') {
      ostream << "\\b";
    } else if (ch == '\f') {
      ostream << "\\f";
    } else if (ch == '\n') {
      ostream << "\\n";
    } else if (ch == '\r') {
      ostream << "\\r";
    } else if (ch == '\t') {
      ostream << "\\t";
    } else if ('\x00' <= ch && ch <= '\x1f') {
      ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0')
              << static_cast<int>(ch);
    } else {
      ostream << ch;
    }
  }
  return ostream.str();
}

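// RecordFunction exit callback. Pops the op from the thread's stack,
// serializes the outputs, and writes the completed node (including the
// operator schema, when available) to the output file.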
void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
  using RunState = ExecutionGraphObserver::RunState;
  auto ob = ObserverManager::get();
  if (ob == nullptr || ctx_ptr == nullptr) {
    return;
  }
  if (ob->getState() == RunState::enabled) {
    auto fc_ptr = dynamic_cast<FunctionCallContext*>(ctx_ptr);
    // TORCH_INTERNAL_ASSERT(fc_ptr != nullptr);
    if (fc_ptr == nullptr) {
      LOG(WARNING) << "FunctionCallContext is nullptr.";
      return;
    }
    auto& fc = *fc_ptr;

    auto outputs = fn.outputs();
    auto num_outputs = fn.num_outputs();
    // There are two cases: for an unboxed kernel, num_outputs ==
    // outputs.size(); for a boxed kernel using the stack, there can be more
    // elements on the stack left over from previous ops.
    VLOG(2) << "outputs: " << num_outputs << " " << outputs.size() << std::endl;
    // TORCH_INTERNAL_ASSERT(num_outputs <= outputs.size());
    if (num_outputs > outputs.size()) {
      LOG(WARNING) << "RecordFunction " << fc.name
                   << " num_outputs=" << num_outputs
                   << " > outputs.size()=" << outputs.size();
      return;
    }
    // Need to account for stack mode, where the outputs are at the end of the
    // stack.
    size_t output_start = outputs.size() - num_outputs;

    std::vector<std::string> output_types;
    std::vector<std::string> output_shapes;
    std::vector<std::string> output_values;
    try {
      const std::lock_guard<std::mutex> lock(ob->g_mutex);
      // Remove the current op id from the stack.
      ob->op_stack[fn.threadId()].pop();
      for (const auto i : c10::irange(output_start, outputs.size())) {
        appendValueInfo(
            *ob, outputs[i], output_values, output_types, output_shapes);
      }

      std::string op_schema_str{};
      const auto op_schema = fn.operator_schema();
      if (op_schema.has_value()) {
        op_schema_str = json_str_escape(c10::toString(op_schema.value()));
      }

      writeJsonNode(
          ob->out,
          fc.name,
          fc.op_id,
          fn.handle(),
          fc.parent_id,
          fc.fw_parent_id,
          fn.seqNr(),
          static_cast<std::underlying_type_t<RecordScope>>(fn.scope()),
          fn.threadId(),
          fn.forwardThreadId(),
          vectorToString(fc.input_values),
          vectorToString(fc.input_shapes),
          vectorToString(fc.input_types),
          vectorToString(output_values),
          vectorToString(output_shapes),
          vectorToString(output_types),
          op_schema_str);
      ob->out << ",";
    } catch (const std::exception& e) {
      LOG(WARNING) << "Exception in execution graph observer: [" << fc.name
                   << " (" << fc.op_id << ")] " << e.what();
    }
  }
}

// Adds the execution graph observer callback functions to the RecordFunction
// global observers.
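// Typical usage (sketch): register the observer once with
// addExecutionGraphObserver("<output path>"), wrap the region of interest
// with enableExecutionGraphObserver()/disableExecutionGraphObserver(), and
// call removeExecutionGraphObserver() at the end to finalize and close the
// output file.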
bool addExecutionGraphObserver(const std::string& output_file_path) {
  // Check if the observer is already initialized.
  if (ObserverManager::get() == nullptr) {
    ObserverManager::push(std::make_shared<ExecutionGraphObserver>());
    auto& ob = *ObserverManager::get();
    ob.pid = processId();
    // Set output
    ob.file_name = output_file_path;
    if (!initExecutionGraphStart(ob)) {
      return false;
    }

    ob.cb_handle = addGlobalCallback(
        RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
            .needsInputs(true)
            .needsOutputs(true)
            .needsIds(true));
    // Default to disabled.
    ob.setState(ExecutionGraphObserver::RunState::disabled);

    VLOG(1) << "Added PyTorch execution graph observer, output="
            << output_file_path;
  } else if (ObserverManager::get()->cb_handle != INVALID_CALLBACK_HANDLE) {
    LOG(WARNING) << "Execution graph observer is already registered.";
  }
  return true;
}

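// Disables the observer if it is still enabled, finalizes the JSON output,
// removes the RecordFunction callback, and destroys the global observer
// instance.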
void removeExecutionGraphObserver() {
  auto ob = ObserverManager::get();
  if (ob != nullptr) {
    if (ob->getState() != ExecutionGraphObserver::RunState::disabled) {
      disableExecutionGraphObserver();
    }

    if (ob->cb_handle != INVALID_CALLBACK_HANDLE) {
      finalizeExecutionGraphOutput(*ob);
      removeCallback(ob->cb_handle);
      ob->cb_handle = INVALID_CALLBACK_HANDLE;
      // Release the current EG observer object and reset.
      TORCH_INTERNAL_ASSERT(
          ObserverManager::pop() != nullptr,
          "Global state ptr cannot be null before resetting");
      VLOG(1) << "Removed PyTorch execution graph observer";
    } else {
      LOG(WARNING) << "Execution graph observer was not registered.";
    }
  } else {
    LOG(WARNING) << "Execution graph observer was not initialized.";
  }
}

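// Switches the observer into the enabled state so the RecordFunction
// callbacks start recording nodes.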
void enableExecutionGraphObserver() {
  VLOG(1) << "enableExecutionGraphObserver()";
  auto& ob = *ObserverManager::get();
  // Make sure we are not already enabled.
  if (ob.getState() == ExecutionGraphObserver::RunState::enabled) {
    LOG(WARNING)
        << "Trying to enable Execution Graph Observer when it's already enabled.";
  } else {
    ob.setState(ExecutionGraphObserver::RunState::enabled);
  }
}

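// Switches the observer back into the disabled state, pausing node collection
// without removing the observer.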
void disableExecutionGraphObserver() {
  VLOG(1) << "disableExecutionGraphObserver()";
  auto& ob = *ObserverManager::get();
  if (ob.getState() != ExecutionGraphObserver::RunState::disabled) {
    ob.setState(ExecutionGraphObserver::RunState::disabled);
  } else {
    LOG(WARNING)
        << "Trying to disable Execution Graph Observer when it's already disabled.";
  }
}
} // namespace impl
} // namespace profiler
} // namespace torch