1 | #ifdef USE_C10D_UCC |
2 | |
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>

#include <sys/stat.h>
#include <cerrno>
#include <cstdlib>
#include <ctime>
#include <fstream>
12 | |
13 | namespace c10d { |
14 | |
15 | void ProcessGroupUCCLogger::initCommsTracer() { |
16 | trace_generator = std::make_shared<CommTraceLogger>(); |
17 | initialized_CommTraceLogger = true; |
18 | } |
19 | |
20 | void ProcessGroupUCCLogger::flushComms(int rank, int world_size) { |
21 | if (!initialized_CommTraceLogger || |
22 | trace_generator->getCommsTrace().empty()) { |
23 | return; |
24 | } |
25 | |
26 | std::string dirname = c10::str("ProcessGroupUCC_trace_np" , world_size); |
27 | time_t now_ = time(0); |
28 | std::tm* ltm = localtime(&now_); |
29 | if (ltm) { |
30 | dirname += c10::str( |
31 | "_" , (1 + ltm->tm_mon), "_" , ltm->tm_mday, "_" , (1900 + ltm->tm_year)); |
32 | } |
33 | |
34 | std::string fullpath = "/tmp/" + dirname; |
35 | char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR" ); |
36 | if (user_path) { |
37 | fullpath = user_path; |
38 | } |
39 | std::string trace_filename = c10::str(fullpath, "/rank" , rank, ".json" ); |
40 | std::ofstream _outfile; |
41 | if (!_outfile.is_open()) { |
42 | if (!mkdir(fullpath.c_str(), 0777)) { |
43 | LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath; |
44 | } else if (errno != EEXIST) { |
45 | return; |
46 | } |
47 | _outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc); |
48 | } |
49 | // flush the traced comms |
50 | if (_outfile.is_open()) { |
51 | _outfile << "[" << c10::Join("," , trace_generator->getCommsTrace()) |
52 | << "\n]" ; |
53 | _outfile.flush(); |
54 | _outfile.close(); |
55 | } |
56 | } |
57 | |
58 | /* unused */ |
59 | void CommTraceLogger::setCurBlock(const std::string& name) { |
60 | curBlocks_.push_back( |
61 | c10::str("\"" , name, "\"" )); // add quote marks for JSON format |
62 | } |
63 | |
64 | /* unused */ |
65 | void CommTraceLogger::popBlock() { |
66 | // TODO: remove specific name |
67 | curBlocks_.pop_back(); |
68 | } |
69 | |
70 | void CommTraceLogger::recordOptionalInfo(int root) { |
71 | curRoot_ = root; |
72 | } |
73 | |
74 | void CommTraceLogger::recordOptionalInfo( |
75 | const std::vector<int64_t>& outputSplitSizes, |
76 | const std::vector<int64_t>& inputSplitSizes) { |
77 | curOutSplitSizes_ = outputSplitSizes; |
78 | curInSplitSizes_ = inputSplitSizes; |
79 | } |
80 | |
81 | void CommTraceLogger::recordComms( |
82 | const std::string& commName, |
83 | const uintptr_t workReq, |
84 | const int rank, |
85 | const int world_size, |
86 | const std::vector<at::Tensor>& inputTensors, |
87 | const std::vector<at::Tensor>& outputTensors) { |
88 | auto inSize = (!inputTensors.empty()) ? inputTensors[0].numel() : 0; |
89 | auto outSize = (!outputTensors.empty()) ? outputTensors[0].numel() : 0; |
90 | auto dtype = |
91 | (!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte; |
92 | auto devType = (!outputTensors.empty()) ? outputTensors[0].device().type() |
93 | : c10::DeviceType::CPU; |
94 | auto now = std::chrono::system_clock::now(); |
95 | static auto startTS = now; |
96 | int64_t time_since_begin = |
97 | std::chrono::duration_cast<std::chrono::nanoseconds>(now - startTS) |
98 | .count(); |
99 | |
100 | // TODO: get markers from torch profiler if enabled |
101 | |
102 | // common fields for all operations |
103 | std::string cur_trace_ = c10::str( |
104 | "\n\t\t\"markers\": [" , |
105 | curBlocks_, |
106 | "]" , |
107 | ",\n\t\t\"startTime_ns\": " , |
108 | time_since_begin, |
109 | ",\n\t\t\"comms\": \"" , |
110 | commName, |
111 | "\"" , |
112 | ",\n\t\t\"req\": " , |
113 | workReq, |
114 | ",\n\t\t\"seqnum\": " , |
115 | seqnum, |
116 | ",\n\t\t\"world_size\": " , |
117 | world_size); |
118 | |
119 | if (inSize > 0 || outSize > 0) { |
120 | // for most collectives - append msg sizes, data type, device type |
121 | cur_trace_ = c10::str( |
122 | cur_trace_, |
123 | ",\n\t\t\"in_msg_size\": " , |
124 | inSize, |
125 | ",\n\t\t\"out_msg_size\": " , |
126 | outSize, |
127 | ",\n\t\t\"dtype\": \"" , |
128 | at::toString(dtype), |
129 | "\",\n\t\t\"devType\": \"" , |
130 | c10::DeviceTypeName(devType), |
131 | "\"" ); |
132 | } |
133 | if (curRoot_ != -1) { |
134 | // append root rank if applicable, e.g., broadcast, gather, scatter |
135 | cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": " , curRoot_); |
136 | } |
137 | if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) { |
138 | // append input and output splits if applicable, e.g., ALLTOALL_BASE |
139 | cur_trace_ = c10::str( |
140 | cur_trace_, |
141 | ",\n\t\t\"in_split\": [" , |
142 | c10::Join("," , curInSplitSizes_), |
143 | "]" |
144 | ",\n\t\t\"out_split\": [" , |
145 | c10::Join("," , curOutSplitSizes_), |
146 | "]" ); |
147 | } |
148 | comms_trace_.push_back(c10::str("\n\t{" , cur_trace_, "\n\t}" )); |
149 | |
150 | // record the trace to kineto trace if applicable |
151 | RECORD_PARAM_COMMS( |
152 | static_cast<int64_t>(seqnum), // seq |
153 | 0, // process group ptr |
154 | rank, |
155 | commName.c_str(), |
156 | inSize, |
157 | outSize, |
158 | dtype, |
159 | curInSplitSizes_, |
160 | curOutSplitSizes_); |
161 | |
162 | ++seqnum; |
163 | |
164 | // reset optional field |
165 | curRoot_ = -1; |
166 | curInSplitSizes_ = {}; |
167 | curOutSplitSizes_ = {}; |
168 | } |
169 | |
170 | } // namespace c10d |
171 | |
172 | #endif // USE_C10D_UCC |
173 | |