1 | // Copyright (c) Meta Platforms, Inc. and affiliates. |
2 | // |
3 | // This source code is licensed under the BSD-style license found in the |
4 | // LICENSE file in the root directory of this source tree. |
5 | |
6 | #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp> |
7 | |
8 | namespace torch { |
9 | |
10 | extern const std::string kParamCommsCallName = "record_param_comms" ; |
11 | |
12 | ParamCommsDebugInfo::ParamCommsDebugInfo( |
13 | int rank, |
14 | std::string&& colName, |
15 | int inSize, |
16 | int outSize, |
17 | at::ScalarType dType, |
18 | std::vector<int64_t> inSplitSizes, |
19 | std::vector<int64_t> outSplitSizes) |
20 | : rank_(rank), |
21 | columnName_(colName), |
22 | inMessageSize_(inSize), |
23 | outMessageSize_(outSize), |
24 | dType_(dType), |
25 | inputSplitSizes_(std::move(inSplitSizes)), |
26 | outputSplitSizes_(std::move(outSplitSizes)) {} |
27 | |
28 | } // namespace torch |
29 | |