#ifdef _WIN32
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>

#include <processthreadsapi.h>
#else
#include <unistd.h>
#endif // _WIN32

#include <fmt/format.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <map>
#include <mutex>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <vector>

#include <ATen/core/TensorBody.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/stack.h>
#include <ATen/record_function.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/standalone/execution_graph_observer.h>
#include <torch/csrc/profiler/util.h>

using namespace at;

namespace torch {
namespace profiler {
namespace impl {

//******************************************************************************
// JSON output utility functions. To be merged with PyTorch profiler.
//******************************************************************************
template <typename T>
inline std::string vectorToString(const std::vector<T>& v) {
  return fmt::format("[{}]", fmt::join(v, ","));
}

constexpr size_t maxNumElements = 4096;

inline std::string getValueType(
    const c10::IValue& val,
    const bool baseType = true,
    const size_t maxArrayLen = maxNumElements) {
  std::string type = val.tagKind();

  if (val.isTensor()) {
    // Add tensor element data type.
    type += fmt::format("({})", std::string(val.toTensor().dtype().name()));
  } else if (val.isTuple()) {
    const auto& val_container = val.toTupleRef().elements();
    std::vector<std::string> str_array;
    for (const auto& t : val_container) {
      str_array.emplace_back(getValueType(t, false));
    }
    type += vectorToString(str_array);
  } else if (val.isList()) {
    const auto& val_list = val.toList();
    std::vector<std::string> str_array;
    str_array.reserve(val_list.size());
    for (const auto j : c10::irange(val_list.size())) {
      str_array.push_back(getValueType(val_list.get(j), false));
      if (j >= maxArrayLen) {
        LOG(WARNING) << "list size=" << val_list.size()
                     << " exceeded maxArrayLen=" << maxArrayLen;
        break;
      }
    }
    type += vectorToString(str_array);
  }
  return baseType ? fmt::format("\"{}\"", type) : type;
}
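
// Example (illustrative): getValueType() renders a float32 tensor argument as
// the quoted string "Tensor(float)", and a (tensor, int) tuple as
// "Tuple[Tensor(float),Int]".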

inline std::string getValueShape(
    const c10::IValue& val,
    const size_t maxArrayLen = maxNumElements) {
  if (val.isTensor()) {
    auto& tensor = val.toTensor();
    if (tensor.defined()) {
      return vectorToString(tensor.sizes().vec());
    }
  } else if (val.isTuple()) {
    const auto& val_container = val.toTupleRef().elements();
    std::vector<std::string> str_array;
    for (const auto& t : val_container) {
      str_array.push_back(getValueShape(t));
    }
    return vectorToString(str_array);
  } else if (val.isList()) {
    const auto& val_list = val.toList();
    std::vector<std::string> str_array;
    str_array.reserve(val_list.size());
    for (const auto j : c10::irange(val_list.size())) {
      str_array.push_back(getValueShape(val_list.get(j)));
      if (j >= maxArrayLen) {
        LOG(WARNING) << "list size=" << val_list.size()
                     << " exceeded maxArrayLen=" << maxArrayLen;
        break;
      }
    }
    return vectorToString(str_array);
  }
  return "[]";
}
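
// Example (illustrative): getValueShape() returns "[2,3]" for a 2x3 tensor and
// "[]" for scalar values or undefined tensors.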

inline std::string getScalarValue(const c10::IValue& val) {
  if (val.isDouble()) {
    double d_val = val.toDouble();
    if (std::isinf(d_val) || std::isnan(d_val)) {
      return fmt::format("\"{}\"", std::to_string(d_val));
    } else {
      return std::to_string(d_val);
    }
  } else if (val.isInt()) {
    return std::to_string(val.toInt());
  } else if (val.isBool()) {
    return val.toBool() ? "true" : "false";
  } else if (val.isString()) {
    const std::string& str_val = val.toStringRef();
    if (str_val.size() > maxNumElements) {
      LOG(WARNING) << "string size=" << str_val.size()
                   << " exceeded maxNumElements=" << maxNumElements;
      return fmt::format("\"{}\"", str_val.substr(0, maxNumElements));
    }

    return fmt::format("\"{}\"", str_val);
  } else if (val.isDevice()) {
    return fmt::format("\"{}\"", val.toDevice().str());
  }
  return fmt::format("\"<{}>\"", val.tagKind());
}
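
// Example (illustrative): getScalarValue() emits 2.5 as the bare number
// 2.500000, inf/NaN doubles and strings as quoted JSON strings, and
// unsupported tag kinds as the quoted placeholder "<TagKind>".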

inline int32_t processId() {
#ifndef _WIN32
  return static_cast<int32_t>(getpid());
#else
  return static_cast<int32_t>(GetCurrentProcessId());
#endif
}

//******************************************************************************
// Main ExecutionGraphObserver implementation.
//******************************************************************************

// ExecutionGraphObserver holds all the state of the observer. Some of it is
// shared between the enter and exit RecordFunction callbacks, and some data,
// such as `op_stack`, may be accessed from different threads, so we have to be
// careful about data races. A global mutex `g_mutex` is used to avoid these
// races, at the cost of performance when many threads are active. This may be
// optimized further with thread-local state, fine-grained locking, or
// thread-safe containers.
struct TORCH_API ExecutionGraphObserver {
  using ID = size_t;

  // Mapping of each thread to its own operator stack.
  std::map<size_t, std::stack<ID>> op_stack{};
  // Maps the underlying TensorImpl object pointer to its unique ID.
  std::map<void*, ID> object_id{};
  // Observer run state.
  enum class RunState { uninitialized, disabled, enabled };

  // Mutex for multithreaded access to the shared containers.
  std::mutex g_mutex{};
  // Stream to write output JSON.
  std::ofstream out{};

  // Full path to the output file.
  std::string file_name{};

  // RecordFunction callback handle for this observer.
  CallbackHandle cb_handle{INVALID_CALLBACK_HANDLE};

  // Process ID.
  int32_t pid{-1};
  std::string record_time{};

  ExecutionGraphObserver() = default;

  // Returns a new unique ID.
  ID getNewID() {
    return id_++;
  }

  RunState getState() const {
    return state_;
  }

  void setState(RunState newState) {
    if (state_ == RunState::uninitialized ||
        callbackShouldBeEnabled(state_) != callbackShouldBeEnabled(newState)) {
      if (callbackShouldBeEnabled(newState)) {
        reenableCallback(cb_handle);
      } else {
        disableCallback(cb_handle);
      }
    }
    state_ = newState;
  }

 private:
  static bool callbackShouldBeEnabled(RunState run_state) {
    return run_state == ExecutionGraphObserver::RunState::enabled;
  }

  // Must use the accessors to change this so that we can keep the
  // RecordFunction callback in sync with the state.
  RunState state_{RunState::uninitialized};

  // Every tensor and operator is assigned a unique ID; the counter is
  // incremented for each new tensor or operator node.
  // 0 -> uninitialized
  // 1 -> root ID
  // 2 ... -> regular node ID
  std::atomic<ID> id_{2};
};

// A singleton manager is used here to initialize and delete the observer
// object.
using ObserverManager = GlobalStateManager<ExecutionGraphObserver>;

// Uninitialized node has id = 0.
const ExecutionGraphObserver::ID uninitialized_id{0};
// Root node has id = 1.
const ExecutionGraphObserver::ID root_id{1};

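// Per-op context created in the RecordFunction enter callback
// (onFunctionEnter) and consumed in the exit callback (onFunctionExit), where
// the node is written out.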
struct FunctionCallContext : public ObserverContext {
  std::string name;
  ExecutionGraphObserver::ID op_id{uninitialized_id};
  ExecutionGraphObserver::ID parent_id{uninitialized_id};
  ExecutionGraphObserver::ID fw_parent_id{uninitialized_id};
  std::vector<std::string> input_types;
  std::vector<std::string> input_shapes;
  std::vector<std::string> input_values;
};

// Opens the JSON file for writing the execution graph.
std::ofstream openOutputFile(const std::string& name) {
  std::ofstream stream;
  stream.open(name, std::ofstream::out | std::ofstream::trunc);
  if (!stream) {
    LOG(ERROR) << "Failed to open '" << name << "'";
  } else {
    VLOG(1) << "Writing PyTorch execution graph to: " << name;
  }
  return stream;
}

void writeJsonNode(
    std::ofstream& out,
    const std::string& name,
    const uint64_t id,
    const uint64_t rf_id,
    const uint64_t parent,
    const uint64_t fw_parent,
    const int64_t seq_id,
    const uint64_t scope,
    const uint64_t tid,
    const uint64_t fw_tid,
    const std::string& inputs = "[]",
    const std::string& input_shapes = "[]",
    const std::string& input_types = "[]",
    const std::string& outputs = "[]",
    const std::string& output_shapes = "[]",
    const std::string& output_types = "[]",
    const std::string& operator_schema = "") {
  out << fmt::format(
      R"JSON(
    {{
      "name": "{}", "id": {}, "rf_id": {}, "parent": {}, "fw_parent": {}, "seq_id": {}, "scope": {}, "tid": {}, "fw_tid": {}, "op_schema": "{}",
      "inputs": {}, "input_shapes": {}, "input_types": {},
      "outputs": {}, "output_shapes": {}, "output_types": {}
    }})JSON",
      name,
      id,
      rf_id,
      parent,
      fw_parent,
      seq_id,
      scope,
      tid,
      fw_tid,
      operator_schema,
      inputs,
      input_shapes,
      input_types,
      outputs,
      output_shapes,
      output_types);
}
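
// Illustrative example of one node as emitted by writeJsonNode() (values are
// made up; whitespace compressed):
//   { "name": "aten::add", "id": 5, "rf_id": 12, "parent": 2, "fw_parent": 0,
//     "seq_id": -1, "scope": 0, "tid": 1, "fw_tid": 0, "op_schema": "...",
//     "inputs": [...], "input_shapes": [[2,3],[2,3]], "input_types": [...],
//     "outputs": [...], "output_shapes": [[2,3]], "output_types": [...] }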

inline std::string timeString(const std::time_t timepoint) {
  std::ostringstream oss;
  oss << std::put_time(std::localtime(&timepoint), "%Y-%m-%d %X");
  return oss.str();
}

bool initExecutionGraphStart(ExecutionGraphObserver& ob) {
  ob.out = openOutputFile(ob.file_name);
  // If the output stream failed to open, stop the observer setup here.
  if (!ob.out) {
    LOG(WARNING) << "Failed to open output file: " << ob.file_name;
    return false;
  }

  // Wall-clock time at which collection of the first op starts.
  const auto current_time = std::chrono::system_clock::now();
  ob.record_time =
      timeString(std::chrono::system_clock::to_time_t(current_time));
  // Start timestamp; uses steady_clock for duration measurement.
  const auto timestamp =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();

  ob.out << fmt::format(
      R"JSON({{
  "schema": "1.0.1", "pid": {}, "time": "{}", "start_ts": {},
  "nodes": [)JSON",
      ob.pid,
      ob.record_time,
      timestamp);
  return true;
}
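
// Note: initExecutionGraphStart() leaves the "nodes" JSON array open; it is
// closed, and the "finish_ts" field written, by finalizeExecutionGraphOutput()
// below.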

// Writes out the remaining execution graph data and closes the output file.
void finalizeExecutionGraphOutput(ExecutionGraphObserver& ob) {
  writeJsonNode(
      ob.out,
      "[pytorch|profiler|execution_graph|process]",
      root_id,
      0, // rf_id
      root_id, // parent is self
      0, // fw_parent
      -1, // seq_id
      static_cast<std::underlying_type_t<RecordScope>>(RecordScope::USER_SCOPE),
      0, // tid
      0); // fw_tid

  // Finish timestamp; uses steady_clock for duration measurement.
  const auto timestamp =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();
  ob.out << fmt::format(
      R"JSON(
  ],
  "finish_ts": {}
}})JSON",
      timestamp);

  ob.out.close();
  VLOG(1) << "PyTorch execution graph is written to file: " << ob.file_name;
}

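// Returns the unique ID for an object (e.g. a TensorImpl or storage data
// pointer), assigning a new one on first use so that repeated appearances of
// the same object map to the same ID.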
inline ExecutionGraphObserver::ID getObjectID(
    ExecutionGraphObserver& ob,
    void* t) {
  auto iter = ob.object_id.find(t);
  if (iter == ob.object_id.end()) {
    ExecutionGraphObserver::ID object_id = ob.getNewID();
    ob.object_id[t] = object_id;
    return object_id;
  }

  return iter->second;
}

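// Converts an IValue to its JSON representation. Tensors are encoded as
// [tensor_id, storage_id, offset, numel, itemsize, "device"]; tuples and lists
// recurse element-wise; everything else falls back to getScalarValue().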
inline std::string convertIValue(
    ExecutionGraphObserver& ob,
    const c10::IValue& val,
    const size_t maxArrayLen = maxNumElements) {
  if (val.isTensor()) {
    const auto t = val.toTensor().unsafeGetTensorImpl();
    ExecutionGraphObserver::ID tensor_id = getObjectID(ob, t);
    ExecutionGraphObserver::ID storage_id = 0;
    size_t offset = 0;
    size_t numel = 0;
    size_t itemsize = 0;
    std::string device_str = "";
    if (t->has_storage()) {
      auto& t_storage = t->storage();
      storage_id = getObjectID(ob, t_storage.data());
      offset = t->storage_offset();
      numel = t->numel();
      itemsize = t->itemsize();
      device_str = t->device().str();
    }
    return fmt::format(
        "[{},{},{},{},{},\"{}\"]",
        tensor_id,
        storage_id,
        offset,
        numel,
        itemsize,
        device_str);
  } else if (val.isTuple()) {
    std::vector<std::string> str_array;
    const auto& val_tuple = val.toTupleRef().elements();
    for (const auto j : c10::irange(val_tuple.size())) {
      str_array.push_back(convertIValue(ob, val_tuple[j]));
    }
    return vectorToString(str_array);
  } else if (val.isList()) {
    const auto& val_list = val.toList();
    std::vector<std::string> str_array;
    str_array.reserve(val_list.size());
    for (const auto j : c10::irange(val_list.size())) {
      str_array.push_back(convertIValue(ob, val_list.get(j)));
      if (j >= maxArrayLen) {
        LOG(WARNING) << "list size=" << val_list.size()
                     << " exceeded maxArrayLen=" << maxArrayLen;
        break;
      }
    }
    return vectorToString(str_array);
  } else {
    return getScalarValue(val);
  }
}

inline void appendValueInfo(
    ExecutionGraphObserver& ob,
    const c10::IValue& val,
    std::vector<std::string>& values,
    std::vector<std::string>& types,
    std::vector<std::string>& shapes) {
  values.push_back(convertIValue(ob, val));
  types.push_back(getValueType(val));
  shapes.push_back(getValueShape(val));
}

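// Records the state of an operator in the RecordFunction enter callback:
// lazily writes a per-thread root node, captures the input
// values/types/shapes, and pushes the new op ID onto the calling thread's
// stack.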
void recordOperatorStart(
    ExecutionGraphObserver& ob,
    FunctionCallContext& fc,
    const RecordFunction& fn) {
  auto tid = fn.threadId();

  try {
    const std::lock_guard<std::mutex> lock(ob.g_mutex);

    // If the current thread's stack is empty, push the root node onto it
    // first.
    if (ob.op_stack[tid].empty()) {
      auto thread_node_id = ob.getNewID();
      ob.op_stack[tid].push(thread_node_id);
      writeJsonNode(
          ob.out,
          "[pytorch|profiler|execution_graph|thread]",
          thread_node_id,
          0, // rf_id
          root_id,
          0, // fw_parent
          -1, // seq_id
          static_cast<std::underlying_type_t<RecordScope>>(
              RecordScope::USER_SCOPE),
          tid,
          0); // fw_tid
      ob.out << ",";
    }
    fc.name = fn.name();
    auto num_inputs = fn.num_inputs();
    const auto inputs = fn.inputs();

    VLOG(2) << "inputs: " << num_inputs << " " << inputs.size() << std::endl;
    // There are two cases: for an unboxed kernel, num_inputs ==
    // inputs.size(); for a boxed kernel using the stack, there may be
    // additional elements on the stack from previous ops.
    // TORCH_INTERNAL_ASSERT(num_inputs <= inputs.size());
    if (num_inputs > inputs.size()) {
      LOG(WARNING) << "RecordFunction " << fc.name
                   << " expected num_inputs=" << num_inputs
                   << " > inputs.size()=" << inputs.size();
      return;
    }
    // Account for stack mode, where the inputs are at the end of the stack.
    size_t input_start = inputs.size() - num_inputs;

    for (const auto i : c10::irange(input_start, inputs.size())) {
      appendValueInfo(
          ob, inputs[i], fc.input_values, fc.input_types, fc.input_shapes);
    }
    fc.parent_id = ob.op_stack[tid].top();
    // Get the parent ID from the forward stack; this can differ for autograd
    // ops, which may execute on a different thread than the original thread
    // (which should have the parent op on its stack).
    auto fw_tid = fn.forwardThreadId();
    if (fw_tid != 0) {
      fc.fw_parent_id = ob.op_stack[fw_tid].top();
    }
    // All input nodes should have id > op_id.
    fc.op_id = ob.getNewID();
    ob.op_stack[tid].push(fc.op_id);

  } catch (const std::exception& e) {
    LOG(WARNING) << "Exception in execution graph observer: " << e.what();
  }
}

std::unique_ptr<ObserverContext> onFunctionEnter(const RecordFunction& fn) {
  using RunState = ExecutionGraphObserver::RunState;
  auto ob = ObserverManager::get();
  if (ob != nullptr && ob->getState() == RunState::enabled) {
    // Record the op.
    auto fc_ptr = std::make_unique<FunctionCallContext>();
    recordOperatorStart(*ob, *fc_ptr, fn);
    return fc_ptr;
  }
  return nullptr;
}

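// Escapes control characters, quotes, and backslashes so the string can be
// embedded in a JSON string literal.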
inline std::string json_str_escape(const std::string& str) {
  std::ostringstream ostream;
  for (char ch : str) {
    if (ch == '"') {
      ostream << "\\\"";
    } else if (ch == '\\') {
      ostream << "\\\\";
    } else if (ch == '\b') {
      ostream << "\\b";
    } else if (ch == '\f') {
      ostream << "\\f";
    } else if (ch == '\n') {
      ostream << "\\n";
    } else if (ch == '\r') {
      ostream << "\\r";
    } else if (ch == '\t') {
      ostream << "\\t";
    } else if ('\x00' <= ch && ch <= '\x1f') {
      ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0')
              << static_cast<int>(ch);
    } else {
      ostream << ch;
    }
  }
  return ostream.str();
}

void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
  using RunState = ExecutionGraphObserver::RunState;
  auto ob = ObserverManager::get();
  if (ob == nullptr || ctx_ptr == nullptr) {
    return;
  }
  if (ob->getState() == RunState::enabled) {
    auto fc_ptr = dynamic_cast<FunctionCallContext*>(ctx_ptr);
    // TORCH_INTERNAL_ASSERT(fc_ptr != nullptr);
    if (fc_ptr == nullptr) {
      LOG(WARNING) << "FunctionCallContext is nullptr.";
      return;
    }
    auto& fc = *fc_ptr;

    auto outputs = fn.outputs();
    auto num_outputs = fn.num_outputs();
    // There are two cases: for an unboxed kernel, num_outputs ==
    // outputs.size(); for a boxed kernel using the stack, there may be
    // additional elements on the stack from previous ops.
    VLOG(2) << "outputs: " << num_outputs << " " << outputs.size() << std::endl;
    // TORCH_INTERNAL_ASSERT(num_outputs <= outputs.size());
    if (num_outputs > outputs.size()) {
      LOG(WARNING) << "RecordFunction " << fc.name
                   << " num_outputs=" << num_outputs
                   << " > outputs.size()=" << outputs.size();
      return;
    }
    // Account for stack mode, where the outputs are at the end of the stack.
    size_t output_start = outputs.size() - num_outputs;

    std::vector<std::string> output_types;
    std::vector<std::string> output_shapes;
    std::vector<std::string> output_values;
    try {
      const std::lock_guard<std::mutex> lock(ob->g_mutex);
      // Remove the current op ID from the stack.
      ob->op_stack[fn.threadId()].pop();
      for (const auto i : c10::irange(output_start, outputs.size())) {
        appendValueInfo(
            *ob, outputs[i], output_values, output_types, output_shapes);
      }

      std::string op_schema_str{};
      const auto op_schema = fn.operator_schema();
      if (op_schema.has_value()) {
        op_schema_str = json_str_escape(c10::toString(op_schema.value()));
      }

      writeJsonNode(
          ob->out,
          fc.name,
          fc.op_id,
          fn.handle(),
          fc.parent_id,
          fc.fw_parent_id,
          fn.seqNr(),
          static_cast<std::underlying_type_t<RecordScope>>(fn.scope()),
          fn.threadId(),
          fn.forwardThreadId(),
          vectorToString(fc.input_values),
          vectorToString(fc.input_shapes),
          vectorToString(fc.input_types),
          vectorToString(output_values),
          vectorToString(output_shapes),
          vectorToString(output_types),
          op_schema_str);
      ob->out << ",";
    } catch (const std::exception& e) {
      LOG(WARNING) << "Exception in execution graph observer: [" << fc.name
                   << " (" << fc.op_id << ")] " << e.what();
    }
  }
}

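// Typical lifecycle, as an illustrative sketch (the output path below is
// hypothetical):
//   addExecutionGraphObserver("eg_output.json"); // register callbacks, open file
//   enableExecutionGraphObserver();              // start recording ops
//   /* ... run the workload to be captured ... */
//   disableExecutionGraphObserver();             // stop recording
//   removeExecutionGraphObserver();              // finalize JSON, unregister
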
// Add the execution graph observer callback functions to the RecordFunction
// global observers.
bool addExecutionGraphObserver(const std::string& output_file_path) {
  // Check if the observer is already initialized.
  if (ObserverManager::get() == nullptr) {
    ObserverManager::push(std::make_shared<ExecutionGraphObserver>());
    auto& ob = *ObserverManager::get();
    ob.pid = processId();
    // Set the output file path.
    ob.file_name = output_file_path;
    if (!initExecutionGraphStart(ob)) {
      return false;
    }

    ob.cb_handle = addGlobalCallback(
        RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
            .needsInputs(true)
            .needsOutputs(true)
            .needsIds(true));
    // Default to disabled.
    ob.setState(ExecutionGraphObserver::RunState::disabled);

    VLOG(1) << "Added PyTorch execution graph observer, output="
            << output_file_path;
  } else if (ObserverManager::get()->cb_handle != INVALID_CALLBACK_HANDLE) {
    LOG(WARNING) << "Execution graph observer is already registered.";
  }
  return true;
}

void removeExecutionGraphObserver() {
  auto ob = ObserverManager::get();
  if (ob != nullptr) {
    if (ob->getState() != ExecutionGraphObserver::RunState::disabled) {
      disableExecutionGraphObserver();
    }

    if (ob->cb_handle != INVALID_CALLBACK_HANDLE) {
      finalizeExecutionGraphOutput(*ob);
      removeCallback(ob->cb_handle);
      ob->cb_handle = INVALID_CALLBACK_HANDLE;
      // Release the current EG observer object and reset.
      TORCH_INTERNAL_ASSERT(
          ObserverManager::pop() != nullptr,
          "Global state ptr cannot be null before resetting");
      VLOG(1) << "Removed PyTorch execution graph observer";
    } else {
      LOG(WARNING) << "Execution graph observer was not registered.";
    }
  } else {
    LOG(WARNING) << "Execution graph observer was not initialized.";
  }
}

void enableExecutionGraphObserver() {
  VLOG(1) << "enableExecutionGraphObserver()";
  auto& ob = *ObserverManager::get();
  // Make sure we are not already enabled.
  if (ob.getState() == ExecutionGraphObserver::RunState::enabled) {
    LOG(WARNING)
        << "Trying to enable Execution Graph Observer when it's already enabled.";
  } else {
    ob.setState(ExecutionGraphObserver::RunState::enabled);
  }
}

void disableExecutionGraphObserver() {
  VLOG(1) << "disableExecutionGraphObserver()";
  auto& ob = *ObserverManager::get();
  if (ob.getState() != ExecutionGraphObserver::RunState::disabled) {
    ob.setState(ExecutionGraphObserver::RunState::disabled);
  } else {
    LOG(WARNING)
        << "Trying to disable Execution Graph Observer when it's already disabled.";
  }
}
} // namespace impl
} // namespace profiler
} // namespace torch