1 | #include <torch/csrc/lazy/core/ir_dump_util.h> |
2 | |
3 | #include <c10/util/Optional.h> |
4 | #include <c10/util/irange.h> |
5 | #include <torch/csrc/lazy/backend/backend_interface.h> |
6 | #include <torch/csrc/lazy/backend/lowering_context.h> |
7 | #include <torch/csrc/lazy/core/ir_util.h> |
8 | |
9 | #include <regex> |
10 | #include <sstream> |
11 | #include <unordered_map> |
12 | |
13 | namespace torch { |
14 | namespace lazy { |
15 | namespace { |
16 | |
// Maps an IR node to the numeric id used to refer to it in the dumps below.
using NodeIdMap = std::unordered_map<const Node*, size_t>;
18 | |
// One `name=value` attribute extracted from a node's string form.
struct AttrTag {
  // Attribute identifier (the text before '=').
  std::string name;
  // Raw attribute text (the text after '=').
  std::string value;
  // Index in the parsed string just past the end of `value`.
  std::string::size_type pos{0};
};
24 | |
// Returns the index just past a ", " separator if one starts at `pos` in
// `node_string`; otherwise returns `pos` unchanged.
std::string::size_type SkipTagSeparator(
    const std::string& node_string,
    std::string::size_type pos) {
  constexpr std::string::size_type kSeparatorLength = 2;
  const bool at_separator =
      node_string.compare(pos, kSeparatorLength, ", ") == 0;
  if (at_separator) {
    return pos + kSeparatorLength;
  }
  return pos;
}
30 | |
31 | c10::optional<AttrTag> ParseAttrTag( |
32 | const std::string& node_string, |
33 | std::string::size_type pos) { |
34 | // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful |
35 | const std::regex tag_regex("^([a-zA-Z0-9_]+)=" ); |
36 | std::smatch match; |
37 | // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful |
38 | if (!std::regex_search( |
39 | node_string.begin() + pos, node_string.end(), match, tag_regex)) { |
40 | return c10::nullopt; |
41 | } |
42 | |
43 | std::string::size_type vpos = match[1].second - node_string.begin() + 1; |
44 | char nested_open = -1; |
45 | char nested_close = -1; |
46 | size_t nest_count = 1; |
47 | AttrTag tag; |
48 | tag.name = match[1].str(); |
49 | for (pos = vpos; pos < node_string.size(); ++pos) { |
50 | if (nested_open < 0) { |
51 | if (SkipTagSeparator(node_string, pos) != pos) { |
52 | break; |
53 | } |
54 | switch (node_string[pos]) { |
55 | case '(': |
56 | nested_open = node_string[pos]; |
57 | nested_close = ')'; |
58 | break; |
59 | case '[': |
60 | nested_open = node_string[pos]; |
61 | nested_close = ']'; |
62 | break; |
63 | case '{': |
64 | nested_open = node_string[pos]; |
65 | nested_close = '}'; |
66 | break; |
67 | } |
68 | } else if (node_string[pos] == nested_close) { |
69 | --nest_count; |
70 | if (nest_count == 0) { |
71 | nest_count = 1; |
72 | nested_open = nested_close = -1; |
73 | } |
74 | } else if (node_string[pos] == nested_open) { |
75 | ++nest_count; |
76 | } |
77 | } |
78 | tag.value = node_string.substr(vpos, pos - vpos); |
79 | tag.pos = pos; |
80 | return tag; |
81 | } |
82 | |
83 | NodeIdMap GenerateIdMap(c10::ArrayRef<const Node*> post_order) { |
84 | NodeIdMap id_map; |
85 | for (auto node : post_order) { |
86 | TORCH_CHECK(id_map.emplace(node, id_map.size()).second, node->ToString()); |
87 | } |
88 | return id_map; |
89 | } |
90 | |
91 | std::unordered_map<const Node*, size_t> GetRootsIds( |
92 | c10::ArrayRef<const Node*> roots) { |
93 | std::unordered_map<const Node*, size_t> roots_ids; |
94 | for (const auto i : c10::irange(roots.size())) { |
95 | roots_ids[roots[i]] = i; |
96 | } |
97 | return roots_ids; |
98 | } |
99 | |
100 | c10::optional<size_t> GetRootNodeId( |
101 | const Node* node, |
102 | const std::unordered_map<const Node*, size_t>& roots_ids) { |
103 | auto it = roots_ids.find(node); |
104 | if (it == roots_ids.end()) { |
105 | return c10::nullopt; |
106 | } |
107 | return it->second; |
108 | } |
109 | |
110 | std::vector<AttrTag> GetNodeTags(const Node* node) { |
111 | std::string node_string = node->ToString(); |
112 | std::string op_string = node->op().ToString(); |
113 | std::string::size_type pos = node_string.find(op_string); |
114 | TORCH_CHECK(pos != std::string::npos, node_string, " : " , op_string); |
115 | pos += op_string.size(); |
116 | std::vector<AttrTag> tags; |
117 | for (;;) { |
118 | pos = SkipTagSeparator(node_string, pos); |
119 | auto tag = ParseAttrTag(node_string, pos); |
120 | if (!tag) { |
121 | break; |
122 | } |
123 | pos = tag->pos; |
124 | tags.push_back(std::move(*tag)); |
125 | } |
126 | return tags; |
127 | } |
128 | |
129 | std::string GenerateDotNodeLabel( |
130 | const Node* node, |
131 | const std::unordered_map<const Node*, size_t>& roots_ids) { |
132 | static const size_t kMaxValueSize = 64; |
133 | std::stringstream ss; |
134 | ss << node->op() << "\\n" << node->shape(); |
135 | for (auto& tag : GetNodeTags(node)) { |
136 | ss << "\\n" << tag.name << "=" ; |
137 | if (tag.value.size() < kMaxValueSize) { |
138 | ss << tag.value; |
139 | } else { |
140 | ss << tag.value.substr(0, kMaxValueSize) << "..." ; |
141 | } |
142 | } |
143 | auto opt_root_id = GetRootNodeId(node, roots_ids); |
144 | if (opt_root_id) { |
145 | ss << "\\nROOT=" << *opt_root_id; |
146 | } |
147 | return ss.str(); |
148 | } |
149 | |
150 | std::string GenerateDotNodeSpec( |
151 | const Node* node, |
152 | const std::unordered_map<const Node*, size_t>& roots_ids) { |
153 | std::stringstream ss; |
154 | ss << "label=\"" << GenerateDotNodeLabel(node, roots_ids) << "\"" ; |
155 | return ss.str(); |
156 | } |
157 | |
158 | std::string GenerateTextNodeSpec(const Node* node, const NodeIdMap& id_map) { |
159 | std::stringstream ss; |
160 | ss << node->shapes() << " " << node->op() << "(" ; |
161 | size_t count = 0; |
162 | for (auto& output : node->operands()) { |
163 | if (count > 0) { |
164 | ss << ", " ; |
165 | } |
166 | ss << "%" << id_map.at(output.node); |
167 | if (output.node->num_outputs() > 1) { |
168 | ss << "." << output.index; |
169 | } |
170 | ++count; |
171 | } |
172 | ss << ")" ; |
173 | for (auto& tag : GetNodeTags(node)) { |
174 | ss << ", " << tag.name << "=" << tag.value; |
175 | } |
176 | return ss.str(); |
177 | } |
178 | |
179 | } // namespace |
180 | |
181 | std::string DumpUtil::ToDot(c10::ArrayRef<const Node*> nodes) { |
182 | auto post_order = Util::ComputePostOrder(nodes); |
183 | return PostOrderToDot(post_order, nodes); |
184 | } |
185 | |
186 | std::string DumpUtil::PostOrderToDot( |
187 | c10::ArrayRef<const Node*> post_order, |
188 | c10::ArrayRef<const Node*> roots) { |
189 | std::unordered_map<const Node*, size_t> roots_ids = GetRootsIds(roots); |
190 | NodeIdMap id_map = GenerateIdMap(post_order); |
191 | std::stringstream ss; |
192 | ss << "digraph G {\n" ; |
193 | for (auto node : post_order) { |
194 | ss << " node" << id_map.at(node) << " [" |
195 | << GenerateDotNodeSpec(node, roots_ids) << "]\n" ; |
196 | } |
197 | for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) { |
198 | const Node* node = *it; |
199 | size_t id = id_map.at(node); |
200 | for (const auto i : c10::irange(node->operands().size())) { |
201 | const Output& output = node->operand(i); |
202 | ss << " node" << id_map.at(output.node) << " -> node" << id; |
203 | if (node->operands().size() > 1) { |
204 | ss << " [label=\"i=" << i; |
205 | if (output.node->num_outputs() > 1) { |
206 | ss << ",o=" << output.index; |
207 | } |
208 | ss << "\"]\n" ; |
209 | } else { |
210 | if (output.node->num_outputs() > 1) { |
211 | ss << " [label=\"o=" << output.index << "\"]" ; |
212 | } |
213 | ss << "\n" ; |
214 | } |
215 | } |
216 | } |
217 | ss << "}\n" ; |
218 | return ss.str(); |
219 | } |
220 | |
221 | std::string DumpUtil::ToText(c10::ArrayRef<const Node*> nodes) { |
222 | auto post_order = Util::ComputePostOrder(nodes); |
223 | return PostOrderToText(post_order, nodes); |
224 | } |
225 | |
226 | std::string DumpUtil::PostOrderToText( |
227 | c10::ArrayRef<const Node*> post_order, |
228 | c10::ArrayRef<const Node*> roots) { |
229 | std::unordered_map<const Node*, size_t> roots_ids = GetRootsIds(roots); |
230 | NodeIdMap id_map = GenerateIdMap(post_order); |
231 | std::stringstream ss; |
232 | ss << "IR {\n" ; |
233 | for (auto node : post_order) { |
234 | auto opt_root_id = GetRootNodeId(node, roots_ids); |
235 | ss << " %" << id_map.at(node) << " = " |
236 | << GenerateTextNodeSpec(node, id_map); |
237 | if (opt_root_id) { |
238 | ss << ", ROOT=" << *opt_root_id; |
239 | } |
240 | ss << ", NodeType=" << typeid(*node).name(); |
241 | ss << "\n" ; |
242 | } |
243 | ss << "}\n" ; |
244 | return ss.str(); |
245 | } |
246 | |
247 | std::string DumpUtil::ToBackend( |
248 | c10::ArrayRef<Value> values, |
249 | const BackendDevice& device) { |
250 | auto lowering_ctx = LoweringContext::Create("IrToBackend" , device); |
251 | for (auto& ir_value : values) { |
252 | lowering_ctx->AddResult(ir_value); |
253 | } |
254 | auto computation = lowering_ctx->Build(); |
255 | return getBackend()->GetComputationBackendText(computation); |
256 | } |
257 | |
258 | } // namespace lazy |
259 | } // namespace torch |
260 | |