1 | #pragma once |
2 | |
3 | #include <string> |
4 | #include <vector> |
5 | #include <c10/macros/Macros.h> |
6 | #include <c10/util/ThreadLocalDebugInfo.h> |
7 | #include <ATen/record_function.h> |
8 | #include <ATen/core/ivalue.h> |
9 | |
10 | namespace torch { |
11 | |
// Name under which the RECORD_PARAM_COMMS and RECORD_PARAM_COMMS_DATA macros
// below record a communication call (it is the first argument they pass to
// RECORD_FUNCTION / RECORD_FUNCTION_WITH_INPUTS_OUTPUTS). Only declared here;
// the definition lives outside this header.
extern TORCH_API const std::string kParamCommsCallName;
13 | |
14 | class TORCH_API ParamCommsDebugInfo |
15 | : public c10::DebugInfoBase { |
16 | |
17 | public: |
18 | ParamCommsDebugInfo() = default; |
19 | ParamCommsDebugInfo( |
20 | int rank, |
21 | std::string&& colName, |
22 | int inSize, |
23 | int outSize, |
24 | at::ScalarType dType, |
25 | std::vector<int64_t> inSplitSizes, |
26 | std::vector<int64_t> outSplitSizes); |
27 | |
28 | ~ParamCommsDebugInfo() override = default; |
29 | |
30 | int getRank() const { |
31 | return rank_; |
32 | } |
33 | |
34 | const std::string getColumnName() const { |
35 | return columnName_; |
36 | } |
37 | |
38 | int getInMessageSize() const { |
39 | return inMessageSize_; |
40 | } |
41 | |
42 | int getOutMessageSize() const { |
43 | return outMessageSize_; |
44 | } |
45 | |
46 | at::ScalarType getDType() const { |
47 | return dType_; |
48 | } |
49 | |
50 | const std::vector<int64_t>& getInputSplitSizes() const { |
51 | return inputSplitSizes_; |
52 | } |
53 | |
54 | const std::vector<int64_t>& getOutputSplitSizes() const { |
55 | return outputSplitSizes_; |
56 | } |
57 | |
58 | private: |
59 | int rank_{}; |
60 | std::string columnName_; |
61 | int inMessageSize_{}; |
62 | int outMessageSize_{}; |
63 | at::ScalarType dType_ = at::kByte; |
64 | std::vector<int64_t> inputSplitSizes_; |
65 | std::vector<int64_t> outputSplitSizes_; |
66 | }; |
67 | |
// Records a communication call with RECORD_FUNCTION under the name
// torch::kParamCommsCallName, while a torch::ParamCommsDebugInfo built from the
// same arguments is installed as c10::DebugInfoKind::PARAM_COMMS_INFO.
//
// The recorded inputs are, in order: seq, pg_ptr, rank, colName,
// inSplitSizes, outSplitSizes. inSize, outSize and dType go only into the
// debug info.
//
// Usage notes:
//  * Expands to statements in the caller's scope, not to a single expression.
//    It declares the locals `paramCommsInfo`, `g`, `paramList` and
//    `paramInputs` (plus whatever RECORD_FUNCTION declares), so it can appear
//    at most once per scope, and those names must not clash with the
//    caller's own.
//  * `rank`, `colName`, `inSplitSizes` and `outSplitSizes` are each expanded
//    twice: once for the debug info and once for the recorded inputs. Pass
//    side-effect-free expressions, and never `std::move(x)`, because the
//    second use would read a moved-from object.
//  * `colName` is wrapped in std::string(...) before it reaches the
//    constructor, which takes std::string&&. This lets a string literal, a
//    temporary, or an lvalue std::string all be passed.
//  * Order matters: `g` is declared before RECORD_FUNCTION is expanded, so
//    any scoped object RECORD_FUNCTION declares is destroyed before `g`. The
//    debug info therefore stays installed for the whole recorded scope.
#define RECORD_PARAM_COMMS(                                                    \
    seq,                                                                       \
    pg_ptr,                                                                    \
    rank,                                                                      \
    colName,                                                                   \
    inSize,                                                                    \
    outSize,                                                                   \
    dType,                                                                     \
    inSplitSizes,                                                              \
    outSplitSizes)                                                             \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>(          \
      rank,                                                                    \
      std::string(colName),                                                    \
      inSize,                                                                  \
      outSize,                                                                 \
      dType,                                                                   \
      inSplitSizes,                                                            \
      outSplitSizes);                                                          \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = {                       \
      c10::IValue(seq),                                                        \
      c10::IValue(pg_ptr),                                                     \
      rank,                                                                    \
      colName,                                                                 \
      inSplitSizes,                                                            \
      outSplitSizes};                                                          \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList);                     \
  RECORD_FUNCTION(torch::kParamCommsCallName, paramInputs);
90 | |
// Same as RECORD_PARAM_COMMS, but also records the tensors. It uses
// RECORD_FUNCTION_WITH_INPUTS_OUTPUTS under the name
// torch::kParamCommsCallName, with a torch::ParamCommsDebugInfo installed as
// c10::DebugInfoKind::PARAM_COMMS_INFO.
//
// The recorded inputs are, in order: InputTensors, seq, pg_ptr, rank,
// colName, inSplitSizes, outSplitSizes. InputTensors comes first, so every
// other index is shifted by one compared with RECORD_PARAM_COMMS. The
// recorded outputs are a single-element vector holding
// c10::IValue(OutputTensors). inSize, outSize and dType go only into the
// debug info.
//
// Usage notes:
//  * Expands to statements in the caller's scope, not to a single expression.
//    It declares the locals `paramCommsInfo`, `g`, `paramList` and
//    `paramInputs` (plus whatever RECORD_FUNCTION_WITH_INPUTS_OUTPUTS
//    declares), so it can appear at most once per scope.
//  * `rank`, `colName`, `inSplitSizes` and `outSplitSizes` are each expanded
//    twice. Pass side-effect-free expressions and never `std::move(x)`.
//  * `colName` is wrapped in std::string(...) before it reaches the
//    constructor, which takes std::string&&. This lets a string literal, a
//    temporary, or an lvalue std::string all be passed.
//  * Order matters: `g` is declared before the recording macro is expanded,
//    so the debug info stays installed for the whole recorded scope.
#define RECORD_PARAM_COMMS_DATA(                                               \
    seq,                                                                       \
    pg_ptr,                                                                    \
    InputTensors,                                                              \
    OutputTensors,                                                             \
    rank,                                                                      \
    colName,                                                                   \
    inSize,                                                                    \
    outSize,                                                                   \
    dType,                                                                     \
    inSplitSizes,                                                              \
    outSplitSizes)                                                             \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>(          \
      rank,                                                                    \
      std::string(colName),                                                    \
      inSize,                                                                  \
      outSize,                                                                 \
      dType,                                                                   \
      inSplitSizes,                                                            \
      outSplitSizes);                                                          \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = {                       \
      c10::IValue(InputTensors),                                               \
      c10::IValue(seq),                                                        \
      c10::IValue(pg_ptr),                                                     \
      rank,                                                                    \
      colName,                                                                 \
      inSplitSizes,                                                            \
      outSplitSizes};                                                          \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList);                     \
  RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(                                         \
      torch::kParamCommsCallName,                                              \
      paramInputs,                                                             \
      std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));
119 | } // namespace torch |
120 | |