1#pragma once
2
3#include <string>
4#include <vector>
5#include <c10/macros/Macros.h>
6#include <c10/util/ThreadLocalDebugInfo.h>
7#include <ATen/record_function.h>
8#include <ATen/core/ivalue.h>
9
10namespace torch {
11
12extern TORCH_API const std::string kParamCommsCallName;
13
14class TORCH_API ParamCommsDebugInfo
15 : public c10::DebugInfoBase {
16
17 public:
18 ParamCommsDebugInfo() = default;
19 ParamCommsDebugInfo(
20 int rank,
21 std::string&& colName,
22 int inSize,
23 int outSize,
24 at::ScalarType dType,
25 std::vector<int64_t> inSplitSizes,
26 std::vector<int64_t> outSplitSizes);
27
28 ~ParamCommsDebugInfo() override = default;
29
30 int getRank() const {
31 return rank_;
32 }
33
34 const std::string getColumnName() const {
35 return columnName_;
36 }
37
38 int getInMessageSize() const {
39 return inMessageSize_;
40 }
41
42 int getOutMessageSize() const {
43 return outMessageSize_;
44 }
45
46 at::ScalarType getDType() const {
47 return dType_;
48 }
49
50 const std::vector<int64_t>& getInputSplitSizes() const {
51 return inputSplitSizes_;
52 }
53
54 const std::vector<int64_t>& getOutputSplitSizes() const {
55 return outputSplitSizes_;
56 }
57
58 private:
59 int rank_{};
60 std::string columnName_;
61 int inMessageSize_{};
62 int outMessageSize_{};
63 at::ScalarType dType_ = at::kByte;
64 std::vector<int64_t> inputSplitSizes_;
65 std::vector<int64_t> outputSplitSizes_;
66};
67
// Expands, in the caller's scope, to statements that:
//   1. build a shared torch::ParamCommsDebugInfo from
//      (rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes);
//   2. install it with a c10::DebugInfoGuard under
//      c10::DebugInfoKind::PARAM_COMMS_INFO;
//   3. call RECORD_FUNCTION(torch::kParamCommsCallName, inputs), where inputs
//      are {seq, pg_ptr, rank, colName, inSplitSizes, outSplitSizes}.
//
// The recorded rank / name / split sizes are read back from the debug-info
// object rather than re-expanding the macro arguments, so every argument is
// evaluated exactly once (safe for temporaries and side-effecting
// expressions). The recorded values are the ones the object stores: rank as
// int, the name as std::string.
//
// The macro introduces the locals `paramCommsInfo`, `g`, `paramList` and
// `paramInputs` into the enclosing scope, so it can be used at most once per
// scope. It is intentionally not wrapped in a block: the guard `g` (and
// whatever RECORD_FUNCTION declares) must stay alive until the end of the
// caller's scope.
// Only block comments would be safe inside the body because of the line
// continuations, so all documentation lives above the #define.
#define RECORD_PARAM_COMMS( \
    seq, \
    pg_ptr, \
    rank, \
    colName, \
    inSize, \
    outSize, \
    dType, \
    inSplitSizes, \
    outSplitSizes) \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
      rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes); \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = { \
      c10::IValue(seq), \
      c10::IValue(pg_ptr), \
      c10::IValue(paramCommsInfo->getRank()), \
      c10::IValue(paramCommsInfo->getColumnName()), \
      c10::IValue(paramCommsInfo->getInputSplitSizes()), \
      c10::IValue(paramCommsInfo->getOutputSplitSizes())}; \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
  RECORD_FUNCTION(torch::kParamCommsCallName, paramInputs);
90
// Variant of the param-comms recording macro that also records tensors.
// Expands, in the caller's scope, to statements that:
//   1. build a shared torch::ParamCommsDebugInfo from
//      (rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes);
//   2. install it with a c10::DebugInfoGuard under
//      c10::DebugInfoKind::PARAM_COMMS_INFO;
//   3. call RECORD_FUNCTION_WITH_INPUTS_OUTPUTS with the name
//      torch::kParamCommsCallName, the inputs
//      {InputTensors, seq, pg_ptr, rank, colName, inSplitSizes, outSplitSizes}
//      and a one-element output vector holding c10::IValue(OutputTensors).
//
// The recorded rank / name / split sizes are read back from the debug-info
// object rather than re-expanding the macro arguments, so every argument is
// evaluated exactly once (safe for temporaries and side-effecting
// expressions). The recorded values are the ones the object stores: rank as
// int, the name as std::string.
//
// The macro introduces the locals `paramCommsInfo`, `g`, `paramList` and
// `paramInputs` into the enclosing scope, so it can be used at most once per
// scope. It is intentionally not wrapped in a block: the guard `g` (and
// whatever the record macro declares) must stay alive until the end of the
// caller's scope.
#define RECORD_PARAM_COMMS_DATA( \
    seq, \
    pg_ptr, \
    InputTensors, \
    OutputTensors, \
    rank, \
    colName, \
    inSize, \
    outSize, \
    dType, \
    inSplitSizes, \
    outSplitSizes) \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
      rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes); \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = { \
      c10::IValue(InputTensors), \
      c10::IValue(seq), \
      c10::IValue(pg_ptr), \
      c10::IValue(paramCommsInfo->getRank()), \
      c10::IValue(paramCommsInfo->getColumnName()), \
      c10::IValue(paramCommsInfo->getInputSplitSizes()), \
      c10::IValue(paramCommsInfo->getOutputSplitSizes())}; \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
  RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \
      torch::kParamCommsCallName, \
      paramInputs, \
      std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));
119} // namespace torch
120