1#ifdef USE_C10D_UCC
2
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>

#include <sys/stat.h>
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
12
13namespace c10d {
14
15void ProcessGroupUCCLogger::initCommsTracer() {
16 trace_generator = std::make_shared<CommTraceLogger>();
17 initialized_CommTraceLogger = true;
18}
19
20void ProcessGroupUCCLogger::flushComms(int rank, int world_size) {
21 if (!initialized_CommTraceLogger ||
22 trace_generator->getCommsTrace().empty()) {
23 return;
24 }
25
26 std::string dirname = c10::str("ProcessGroupUCC_trace_np", world_size);
27 time_t now_ = time(0);
28 std::tm* ltm = localtime(&now_);
29 if (ltm) {
30 dirname += c10::str(
31 "_", (1 + ltm->tm_mon), "_", ltm->tm_mday, "_", (1900 + ltm->tm_year));
32 }
33
34 std::string fullpath = "/tmp/" + dirname;
35 char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR");
36 if (user_path) {
37 fullpath = user_path;
38 }
39 std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json");
40 std::ofstream _outfile;
41 if (!_outfile.is_open()) {
42 if (!mkdir(fullpath.c_str(), 0777)) {
43 LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath;
44 } else if (errno != EEXIST) {
45 return;
46 }
47 _outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc);
48 }
49 // flush the traced comms
50 if (_outfile.is_open()) {
51 _outfile << "[" << c10::Join(",", trace_generator->getCommsTrace())
52 << "\n]";
53 _outfile.flush();
54 _outfile.close();
55 }
56}
57
58/* unused */
59void CommTraceLogger::setCurBlock(const std::string& name) {
60 curBlocks_.push_back(
61 c10::str("\"", name, "\"")); // add quote marks for JSON format
62}
63
64/* unused */
65void CommTraceLogger::popBlock() {
66 // TODO: remove specific name
67 curBlocks_.pop_back();
68}
69
70void CommTraceLogger::recordOptionalInfo(int root) {
71 curRoot_ = root;
72}
73
74void CommTraceLogger::recordOptionalInfo(
75 const std::vector<int64_t>& outputSplitSizes,
76 const std::vector<int64_t>& inputSplitSizes) {
77 curOutSplitSizes_ = outputSplitSizes;
78 curInSplitSizes_ = inputSplitSizes;
79}
80
81void CommTraceLogger::recordComms(
82 const std::string& commName,
83 const uintptr_t workReq,
84 const int rank,
85 const int world_size,
86 const std::vector<at::Tensor>& inputTensors,
87 const std::vector<at::Tensor>& outputTensors) {
88 auto inSize = (!inputTensors.empty()) ? inputTensors[0].numel() : 0;
89 auto outSize = (!outputTensors.empty()) ? outputTensors[0].numel() : 0;
90 auto dtype =
91 (!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte;
92 auto devType = (!outputTensors.empty()) ? outputTensors[0].device().type()
93 : c10::DeviceType::CPU;
94 auto now = std::chrono::system_clock::now();
95 static auto startTS = now;
96 int64_t time_since_begin =
97 std::chrono::duration_cast<std::chrono::nanoseconds>(now - startTS)
98 .count();
99
100 // TODO: get markers from torch profiler if enabled
101
102 // common fields for all operations
103 std::string cur_trace_ = c10::str(
104 "\n\t\t\"markers\": [",
105 curBlocks_,
106 "]",
107 ",\n\t\t\"startTime_ns\": ",
108 time_since_begin,
109 ",\n\t\t\"comms\": \"",
110 commName,
111 "\"",
112 ",\n\t\t\"req\": ",
113 workReq,
114 ",\n\t\t\"seqnum\": ",
115 seqnum,
116 ",\n\t\t\"world_size\": ",
117 world_size);
118
119 if (inSize > 0 || outSize > 0) {
120 // for most collectives - append msg sizes, data type, device type
121 cur_trace_ = c10::str(
122 cur_trace_,
123 ",\n\t\t\"in_msg_size\": ",
124 inSize,
125 ",\n\t\t\"out_msg_size\": ",
126 outSize,
127 ",\n\t\t\"dtype\": \"",
128 at::toString(dtype),
129 "\",\n\t\t\"devType\": \"",
130 c10::DeviceTypeName(devType),
131 "\"");
132 }
133 if (curRoot_ != -1) {
134 // append root rank if applicable, e.g., broadcast, gather, scatter
135 cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": ", curRoot_);
136 }
137 if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) {
138 // append input and output splits if applicable, e.g., ALLTOALL_BASE
139 cur_trace_ = c10::str(
140 cur_trace_,
141 ",\n\t\t\"in_split\": [",
142 c10::Join(",", curInSplitSizes_),
143 "]"
144 ",\n\t\t\"out_split\": [",
145 c10::Join(",", curOutSplitSizes_),
146 "]");
147 }
148 comms_trace_.push_back(c10::str("\n\t{", cur_trace_, "\n\t}"));
149
150 // record the trace to kineto trace if applicable
151 RECORD_PARAM_COMMS(
152 static_cast<int64_t>(seqnum), // seq
153 0, // process group ptr
154 rank,
155 commName.c_str(),
156 inSize,
157 outSize,
158 dtype,
159 curInSplitSizes_,
160 curOutSplitSizes_);
161
162 ++seqnum;
163
164 // reset optional field
165 curRoot_ = -1;
166 curInSplitSizes_ = {};
167 curOutSplitSizes_ = {};
168}
169
170} // namespace c10d
171
172#endif // USE_C10D_UCC
173